#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"

using namespace mlir;
using namespace mlir::amdgpu;

constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);
namespace {
struct ArithToAMDGPUConversionPass final
    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
  using impl::ArithToAMDGPUConversionPassBase<
      ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;

  void runOnOperation() override;
};
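// Rewrites arith.extf whose source element type is one of the FP8 formats
// supported by the target chipset into amdgpu.ext_packed_fp8, which extends
// FP8 values packed into a 32-bit register.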
struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
  Chipset chipset;
  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
                             PatternBenefit benefit)
      : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}

  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const override;
};

struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
  bool saturateFP8 = false;
  Chipset chipset;
  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                               Chipset chipset, PatternBenefit benefit)
      : OpRewritePattern::OpRewritePattern(ctx, benefit),
        saturateFP8(saturateFP8), chipset(chipset) {}

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};
struct TruncfToFloat16RewritePattern final
    : OpRewritePattern<arith::TruncFOp> {
  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};

struct ScalingExtFRewritePattern final
    : OpRewritePattern<arith::ScalingExtFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                PatternRewriter &rewriter) const override;
};

struct ScalingTruncFRewritePattern final
    : OpRewritePattern<arith::ScalingTruncFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
                                PatternRewriter &rewriter) const override;
};
} // namespace
static bool isSupportedF8(Type elementType, Chipset chipset) {
  if (chipset == kGfx942)
    return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
  if (hasOcpFp8(chipset))
    return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
  return false;
}
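// Casts the f32 produced by the hardware extension ops to the requested
// destination element type, truncating or extending as the bit width demands.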
static Value castF32To(Type desType, Value f32, Location loc,
                       PatternRewriter &rewriter) {
  Type elementType = getElementTypeOrSelf(desType);
  if (elementType.isF32())
    return f32;
  if (elementType.getIntOrFloatBitWidth() < 32)
    return arith::TruncFOp::create(rewriter, loc, desType, f32);
  if (elementType.getIntOrFloatBitWidth() > 32)
    return arith::ExtFOp::create(rewriter, loc, desType, f32);
  llvm_unreachable("The only 32-bit float type is f32");
}
LogicalResult
ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
                                            PatternRewriter &rewriter) const {
  Type inType = op.getIn().getType();
  auto inVecType = dyn_cast<VectorType>(inType);
  if (inVecType) {
    if (inVecType.isScalable())
      return failure();
    inType = inVecType.getElementType();
  }
  if (!isSupportedF8(inType, chipset))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType extResType = VectorType::get(2, rewriter.getF32Type());
  if (!inVecType) {
    Value asFloat = amdgpu::ExtPackedFp8Op::create(
        rewriter, loc, rewriter.getF32Type(), in, 0);
    Value result = castF32To(outElemType, asFloat, loc, rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }

  int64_t numElements = inVecType.getNumElements();
  Value zero = arith::ConstantOp::create(
      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  VectorType outType = cast<VectorType>(op.getOut().getType());

  if (inVecType.getShape().empty()) {
    Value zerodSplat =
        rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
    Value scalarIn =
        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
    Value scalarExt =
        arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
    Value result = vector::InsertOp::create(rewriter, loc, scalarExt,
                                            zerodSplat, ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }

  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outType.getElementType());
  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

  if (inVecType.getRank() > 1) {
    inVecType = VectorType::get(SmallVector<int64_t>{numElements},
                                inVecType.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
  }

  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i,
                                                          elemsThisOp, 1);
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      if (i + j + 1 < numElements) {
        Value asFloats = amdgpu::ExtPackedFp8Op::create(
            rewriter, loc, extResType, inSlice, j / 2);
        Type desType = VectorType::get(2, outElemType);
        Value asType = castF32To(desType, asFloats, loc, rewriter);
        result = vector::InsertStridedSliceOp::create(rewriter, loc, asType,
                                                      result, i + j, 1);
      } else {
        Value asFloat = amdgpu::ExtPackedFp8Op::create(
            rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
        Value asType = castF32To(outElemType, asFloat, loc, rewriter);
        result =
            vector::InsertOp::create(rewriter, loc, asType, result, i + j);
      }
    }
  }

  if (inVecType.getRank() != outType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
  Type type = value.getType();
  if (type.isF32())
    return value;
  if (type.getIntOrFloatBitWidth() < 32)
    return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
  if (type.getIntOrFloatBitWidth() > 32)
    return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(),
                                   value);
  llvm_unreachable("The only 32-bit float type is f32");
}
static Value clampInput(PatternRewriter &rewriter, Location loc,
                        Type outElemType, Value source) {
  Type sourceType = source.getType();
  const llvm::fltSemantics &sourceSem =
      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
  const llvm::fltSemantics &targetSem =
      cast<FloatType>(outElemType).getFloatSemantics();

  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
  bool ignoredLosesInfo = false;
  // Promoting from a smaller type to a larger one cannot lose precision, so
  // the conversion status can be ignored.
  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);

  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);

  Value inf = createScalarOrSplatConstant(
      rewriter, loc, sourceType,
      APFloat::getInf(sourceSem, /*Negative=*/false));
  Value negInf = createScalarOrSplatConstant(
      rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, inf);
  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, negInf);
  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::UNO, source, source);
  Value isNonFinite = arith::OrIOp::create(
      rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf),
      isNan);

  Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
  Value clamped =
      arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
  Value res =
      arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
  return res;
}
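// Packs pairs of source values into bytes of a 32-bit result via
// amdgpu.packed_trunc_2xfp8, clamping first when saturation was requested.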
LogicalResult
TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
                                              PatternRewriter &rewriter) const {
  // Only supporting default rounding mode as of now.
  if (op.getRoundingmodeAttr())
    return failure();
  Type outType = op.getOut().getType();
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  auto inType =
      dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
  if (inType && inType.getWidth() <= 8 && saturateFP8)
    // Conversion between 8-bit floats is not supported with fp8 saturation.
    return failure();
  if (!isSupportedF8(outType, chipset))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  if (saturateFP8)
    in = clampInput(rewriter, loc, outElemType, in);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());
  VectorType truncResType = VectorType::get(4, outElemType);
  if (!inVectorTy) {
    Value asFloat = castToF32(in, loc, rewriter);
    Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(
        rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
        /*existing=*/nullptr);
    Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }

  int64_t numElements = outVecType.getNumElements();
  Value zero = arith::ConstantOp::create(
      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  if (outVecType.getShape().empty()) {
    Value scalarIn =
        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
    Value scalarTrunc =
        arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
    Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
                                            ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }

  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outVecType.getElementType());
  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
  }

  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value thisResult = nullptr;
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
      Value asFloatA = castToF32(elemA, loc, rewriter);
      Value asFloatB = nullptr;
      if (j + 1 < elemsThisOp) {
        Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
        asFloatB = castToF32(elemB, loc, rewriter);
      }
      thisResult = amdgpu::PackedTrunc2xFp8Op::create(
          rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
    }
    if (elemsThisOp < 4)
      thisResult = vector::ExtractStridedSliceOp::create(
          rewriter, loc, thisResult, 0, elemsThisOp, 1);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
                                                  result, i, 1);
  }

  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}
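// Lowers f32 -> f16 truncation to rocdl.cvt.pkrtz, which converts two f32
// values at a time with round-toward-zero; an odd trailing element is paired
// with a poison operand.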
LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
    arith::TruncFOp op, PatternRewriter &rewriter) const {
  Type outType = op.getOut().getType();
  Type inputType = getElementTypeOrSelf(op.getIn());
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  if (!(outType.isF16() && inputType.isF32()))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType truncResType = VectorType::get(2, outElemType);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());

  // Handle the case where the input type is not a vector type.
  if (!inVectorTy) {
    auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
    Value asF16s =
        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
    Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = outVecType.getNumElements();
  Value zero = rewriter.createOrFold<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
    VectorType flatOutTy = VectorType::get(SmallVector<int64_t>{numElements},
                                           outVecType.getElementType());
    result = vector::ShapeCastOp::create(rewriter, loc, flatOutTy, result);
  }

  // Handle the vector case, including the (uncommon) case where the vector
  // length is odd.
  for (int64_t i = 0; i < numElements; i += 2) {
    int64_t elemsThisOp = std::min(numElements, i + 2) - i;
    Value thisResult = nullptr;
    Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
    Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());

    if (elemsThisOp == 2) {
      elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
    }

    thisResult =
        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
    thisResult = vector::ExtractStridedSliceOp::create(
        rewriter, loc, thisResult, 0, elemsThisOp, 1);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
                                                  result, i, 1);
  }

  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}
/// Get the broadcasted / splatted value for a chain of ops.
static Value getOriginalVectorValue(Value value) {
  Value current = value;
  while (Operation *definingOp = current.getDefiningOp()) {
    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
                      .Case<vector::ShapeCastOp>([&current](auto op) {
                        current = op.getSource();
                        return true;
                      })
                      .Case<vector::BroadcastOp>([&current](auto op) {
                        current = op.getSource();
                        return false;
                      })
                      .Case<vector::SplatOp>([&current](auto op) {
                        current = op.getInput();
                        return false;
                      })
                      .Default([](Operation *) { return false; });

    if (!skipOp) {
      break;
    }
  }
  return current;
}
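// Tiles the input into blocks that share a single scale (recovered from the
// broadcast or splat feeding the scale operand) and expands each block with
// amdgpu.scaled_ext_packed, two output elements per op.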
LogicalResult
ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                           PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  constexpr int64_t opOutWidth = 2;

  Value in = op.getIn();
  Value scale = op.getScale();
  Value out = op.getOut();

  Type f32 = rewriter.getF32Type();
  Type inType = getElementTypeOrSelf(in);
  Type scaleType = getElementTypeOrSelf(scale);
  Type outType = getElementTypeOrSelf(out);

  int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();

  VectorType outVecType = dyn_cast<VectorType>(out.getType());
  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
  if (outVecType && outVecType.isScalable())
    return failure();

  Type scaleF32Type =
      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
  if (scaleType.getIntOrFloatBitWidth() < 32)
    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
  else if (scaleType.getIntOrFloatBitWidth() > 32)
    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);

  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);

  if (!outVecType) {
    Value inCast = vector::BroadcastOp::create(rewriter, loc,
                                               VectorType::get(1, inType), in);
    Value scaleExt = amdgpu::ScaledExtPackedOp::create(
        rewriter, loc, extScaleResultType, inCast, scale, 0);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
    return success();
  }

  VectorType inVecType = cast<VectorType>(in.getType());
  Value origScale = getOriginalVectorValue(op.getScale());
  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());

  ArrayRef<int64_t> inShape = inVecType.getShape();
  SmallVector<int64_t> originalScaleShape;
  if (origScaleVecType)
    llvm::append_range(originalScaleShape, origScaleVecType.getShape());

  originalScaleShape.insert(originalScaleShape.end(),
                            inShape.size() - originalScaleShape.size(), 1);

  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
  assert(maybeRatio &&
         "failed to derive block size from broadcast or splat operation");

  SmallVector<int64_t> ratio =
      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
  int64_t blockSize = computeProduct(ratio);

  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                         rewriter.getFloatAttr(outType, 0.0));
  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
    SmallVector<int64_t> strides(offsets.size(), 1);
    Value block = vector::ExtractStridedSliceOp::create(
        rewriter, loc, in, offsets, ratio, strides);
    VectorType block1DType = VectorType::get(blockSize, inType);
    Value block1D =
        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
    Value uniformScale =
        vector::ExtractOp::create(rewriter, loc, scale, offsets);

    VectorType blockResultType = VectorType::get(blockSize, outType);
    Value blockResult =
        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
         i < blockSize;
         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
      Value inSlice = vector::ExtractStridedSliceOp::create(
          rewriter, loc, block1D, i, inSliceWidth, 1);
      for (int64_t j = 0,
                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
           j < inSliceWidth; j += outSliceWidth,
                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
            rewriter, loc, extScaleResultType, inSlice, uniformScale,
            j / opOutWidth);
        if (outSliceWidth < opOutWidth) {
          scaleExt = vector::ExtractStridedSliceOp::create(
              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
        }
        blockResult = vector::InsertStridedSliceOp::create(
            rewriter, loc, scaleExt, blockResult, i + j, 1);
      }
    }

    VectorType resultType = VectorType::get(ratio, outType);
    Value cast =
        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
                                                  offsets, strides);
  }

  rewriter.replaceOp(op, result);
  return success();
}
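// The inverse of the pattern above: blocks that share one scale are truncated
// two input elements at a time with amdgpu.packed_scaled_trunc, accumulating
// into a packed 32-bit result.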
LogicalResult
ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                             PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  constexpr int64_t opInWidth = 2;

  Value in = op.getIn();
  Value scale = op.getScale();
  Value out = op.getOut();

  Type f32 = rewriter.getF32Type();
  Type inType = getElementTypeOrSelf(in);
  Type scaleType = getElementTypeOrSelf(scale);
  Type outType = getElementTypeOrSelf(out);

  VectorType outVecType = dyn_cast<VectorType>(out.getType());
  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
  if (outVecType && outVecType.isScalable())
    return failure();

  Type scaleF32Type =
      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
  if (scaleType.getIntOrFloatBitWidth() < 32)
    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
  else if (scaleType.getIntOrFloatBitWidth() > 32)
    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);

  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                         rewriter.getFloatAttr(outType, 0.0));
  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);

  if (!outVecType) {
    Type inVecType = VectorType::get(1, inType);
    Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in);
    Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
        rewriter, loc, truncScaleResultType, inCast, scale, 0,
        /*existing=*/nullptr);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
    return success();
  }

  VectorType inVecType = cast<VectorType>(in.getType());
  Value origScale = getOriginalVectorValue(op.getScale());
  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());

  ArrayRef<int64_t> inShape = inVecType.getShape();
  SmallVector<int64_t> scaleShape;
  if (origScaleVecType)
    llvm::append_range(scaleShape, origScaleVecType.getShape());

  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);

  auto maybeRatio = computeShapeRatio(inShape, scaleShape);
  assert(maybeRatio &&
         "failed to derive block size from broadcast or splat operation");

  SmallVector<int64_t> ratio =
      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
  int64_t blockSize = computeProduct(ratio);

  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
    SmallVector<int64_t> strides(offsets.size(), 1);
    Value block = vector::ExtractStridedSliceOp::create(
        rewriter, loc, in, offsets, ratio, strides);
    VectorType block1DType = VectorType::get(blockSize, inType);
    Value block1D =
        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
    Value uniformScale =
        vector::ExtractOp::create(rewriter, loc, scale, offsets);

    VectorType blockResultType = VectorType::get(blockSize, outType);
    Value blockResult =
        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
         i < blockSize; i += outSliceWidth,
                 outSliceWidth = std::min(opOutWidth, blockSize - i)) {
      Value scaleTrunc;
      // Case where at most opInWidth elements are being truncated.
      if (outSliceWidth <= opInWidth) {
        Value slice = vector::ExtractStridedSliceOp::create(
            rewriter, loc, block1D, i, outSliceWidth, 1);
        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
            /*existing=*/nullptr);
      } else {
        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
                                                 truncScaleResultType, zero);
        for (int64_t j = 0,
                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
             j < outSliceWidth; j += opInWidth,
                     inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
          Value slice = vector::ExtractStridedSliceOp::create(
              rewriter, loc, block1D, i + j, inSliceWidth, 1);
          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
              rewriter, loc, truncScaleResultType, slice, uniformScale,
              j / opInWidth, scaleTrunc);
        }
      }
      if (outSliceWidth != opOutWidth) {
        scaleTrunc = vector::ExtractStridedSliceOp::create(
            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
      }
      blockResult = vector::InsertStridedSliceOp::create(
          rewriter, loc, scaleTrunc, blockResult, i, 1);
    }

    VectorType resultType = VectorType::get(ratio, outType);
    Value cast =
        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
                                                  offsets, strides);
  }

  rewriter.replaceOp(op, result);
  return success();
}
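// Public entry point: callers select which rewrites to install via the
// boolean flags; the scaled-conversion patterns additionally require a
// gfx950-or-newer chipset.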
void mlir::arith::populateArithToAMDGPUConversionPatterns(
    RewritePatternSet &patterns, bool convertFP8Arithmetic,
    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
    PatternBenefit benefit) {
  if (convertFP8Arithmetic) {
    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
                                             benefit);
    patterns.add<TruncFToFloat8RewritePattern>(
        patterns.getContext(), saturateFP8Truncf, chipset, benefit);
  }
  if (allowPackedF16Rtz)
    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(),
                                                benefit);
  if (chipset >= kGfx950) {
    patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
  }
}
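// The pass driver: parse the chipset string, decide whether FP8 arithmetic
// conversion applies (gfx942 or an OCP-FP8 chipset), and apply the patterns
// greedily.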
void ArithToAMDGPUConversionPass::runOnOperation() {
  Operation *op = getOperation();
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
    return signalPassFailure();
  }
  bool convertFP8Arithmetic =
      *maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
      *maybeChipset);
  if (failed(applyPatternsGreedily(op, std::move(patterns))))
    return signalPassFailure();
}
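// A minimal usage sketch, not part of the original file: one plausible way to
// drive these conversions directly, assuming a hypothetical ModuleOp and a
// chipset string such as "gfx942". It relies only on entry points used above:
// Chipset::parse, populateArithToAMDGPUConversionPatterns, and
// applyPatternsGreedily. Unlike the pass, it enables FP8 arithmetic
// conversion unconditionally for brevity.
static LogicalResult convertArithToAMDGPU(ModuleOp module,
                                          StringRef chipsetName) {
  FailureOr<amdgpu::Chipset> maybeChipset =
      amdgpu::Chipset::parse(chipsetName);
  if (failed(maybeChipset))
    return module.emitError("invalid chipset name: ") << chipsetName;
  RewritePatternSet patterns(module.getContext());
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/false,
      /*allowPackedF16Rtz=*/true, *maybeChipset);
  return applyPatternsGreedily(module, std::move(patterns));
}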