// NOTE(review): fragment — the enclosing helper's signature and earlier lines
// are not visible in this chunk. From what is visible, it converts an f32
// value `f32` to `desType` via arith.truncf / arith.extf; presumably width
// checks on `elementType` (missing here) pick between the two — TODO confirm
// against the full file.
110 if (elementType.
isF32())
113 return arith::TruncFOp::create(rewriter, loc, desType, f32);
// Fallthrough extension path; reached for wider element types — presumably.
115 return arith::ExtFOp::create(rewriter, loc, desType, f32);
// Guard: a 32-bit float element that is not f32 should be impossible.
116 llvm_unreachable(
"The only 32-bit float type is f32");
// Lowers arith.extf from an 8-bit float source to amdgpu.ext_packed_fp8.
// NOTE(review): lossy extract — many original lines are missing, so guards,
// builder operands, and closing braces are only partially visible below.
120ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
122 Type inType = op.getIn().getType();
123 auto inVecType = dyn_cast<VectorType>(inType);
// Scalable vectors are presumably rejected here (bail-out body not visible).
125 if (inVecType.isScalable())
127 inType = inVecType.getElementType();
// Scalar path: a single packed-ext produces a 2xf32 result vector.
133 Value in = op.getIn();
135 VectorType extResType = VectorType::get(2, rewriter.
getF32Type());
137 Value asFloat = amdgpu::ExtPackedFp8Op::create(
// Vector path: materialize a zero-filled accumulator of the result type.
143 int64_t numElements = inVecType.getNumElements();
145 Value zero = arith::ConstantOp::create(
147 VectorType outType = cast<VectorType>(op.getOut().getType());
// 0-D vector special case: extract the scalar, extf it, re-insert.
149 if (inVecType.getShape().empty()) {
155 arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
156 Value result = vector::InsertOp::create(rewriter, loc, scalarExt,
163 outType.getElementType());
// Multi-dimensional inputs are flattened to 1-D before slicing.
166 if (inVecType.getRank() > 1) {
168 inVecType.getElementType());
169 in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
// Main loop: process up to 4 source lanes per iteration (the fp8 packed-op
// width); `elemsThisOp` handles the ragged tail.
172 for (
int64_t i = 0; i < numElements; i += 4) {
173 int64_t elemsThisOp = std::min(numElements, i + 4) - i;
174 Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i,
177 if (i +
j + 1 < numElements) {
178 Value asFloats = amdgpu::ExtPackedFp8Op::create(
179 rewriter, loc, extResType, inSlice,
j / 2);
180 Type desType = VectorType::get(2, outElemType);
182 result = vector::InsertStridedSliceOp::create(rewriter, loc, asType,
185 Value asFloat = amdgpu::ExtPackedFp8Op::create(
188 result = vector::InsertOp::create(rewriter, loc, asType,
result, i +
j);
// Restore the caller-visible shape if the input was flattened above.
193 if (inVecType.getRank() != outType.getRank()) {
194 result = vector::ShapeCastOp::create(rewriter, loc, outType,
result);
// NOTE(review): fragment of a width-normalizing cast helper (signature and
// the preceding width checks are not visible). Visible behavior: extend
// narrower floats to f32, truncate wider ones to f32.
206 return arith::ExtFOp::create(rewriter, loc, rewriter.
getF32Type(), value);
208 return arith::TruncFOp::create(rewriter, loc, rewriter.
getF32Type(), value);
// Guard: any 32-bit float must be f32 itself.
209 llvm_unreachable(
"The only 32-bit float type is f32");
// NOTE(review): fragment of a saturation helper (head not visible). Visible
// logic: clamp `source` to the target float type's finite range while
// passing +/-inf and NaN through unchanged via an arith.select.
220 const llvm::fltSemantics &sourceSem =
222 const llvm::fltSemantics &targetSem =
223 cast<FloatType>(outElemType).getFloatSemantics();
// Clamp bounds: most-negative / most-positive finite values representable
// in the target semantics.
225 APFloat
min = APFloat::getLargest(targetSem,
true);
226 APFloat
max = APFloat::getLargest(targetSem,
false);
227 bool ignoredLosesInfo =
false;
// Convert the bounds into the source semantics so the compare/clamp ops
// below operate on matching types; precision loss is irrelevant here.
231 (
void)
min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
232 (
void)
max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
238 rewriter, loc, sourceType,
239 APFloat::getInf(sourceSem,
false));
241 rewriter, loc, sourceType, APFloat::getInf(sourceSem,
true));
// Detect non-finite inputs: +inf, -inf, or NaN (UNO is true iff unordered).
243 loc, arith::CmpFPredicate::OEQ, source, inf);
245 loc, arith::CmpFPredicate::OEQ, source, negInf);
247 loc, arith::CmpFPredicate::UNO, source, source);
248 Value isNonFinite = arith::OrIOp::create(
249 rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf),
252 Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
254 arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
// Keep the original value for non-finite inputs, clamped value otherwise.
256 arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
// Lowers arith.truncf to an 8-bit float type via amdgpu.packed_trunc_2xfp8,
// two f32 inputs per packed op, four output bytes per result register.
// NOTE(review): lossy extract — guards and several builder operands are
// missing between the visible lines.
261TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
262 PatternRewriter &rewriter)
const {
// Bail out when the op carries an explicit rounding mode (body not visible).
264 if (op.getRoundingmodeAttr())
266 Type outType = op.getOut().getType();
267 auto outVecType = dyn_cast<VectorType>(outType);
269 if (outVecType.isScalable())
271 outType = outVecType.getElementType();
// Presumably rejects small-integer-ish inputs under saturation — the
// guarded statement is not visible; TODO confirm.
274 if (inType && inType.getWidth() <= 8 && saturateFP8)
281 Location loc = op.getLoc();
282 Value in = op.getIn();
// When saturating, clamp the input to the fp8 type's finite range first.
285 in =
clampInput(rewriter, loc, outElemType, in);
286 auto inVectorTy = dyn_cast<VectorType>(in.
getType());
287 VectorType truncResType = VectorType::get(4, outElemType);
// Scalar path: normalize to f32, pack into lane 0, extract element 0.
289 Value asFloat =
castToF32(in, loc, rewriter);
290 Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(
291 rewriter, loc, truncResType, asFloat,
nullptr, 0,
293 Value
result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
// Vector path: build a zero-filled accumulator of the output type.
298 int64_t numElements = outVecType.getNumElements();
299 Value zero = arith::ConstantOp::create(
300 rewriter, loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
// 0-D vector special case: extract, truncf, re-insert as 0-D.
301 if (outVecType.getShape().empty()) {
303 vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
306 arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
307 Value
result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
308 ArrayRef<int64_t>{});
// Flatten multi-dimensional inputs to 1-D before the chunked loop.
313 VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
314 outVecType.getElementType());
317 if (inVectorTy.getRank() > 1) {
318 inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
319 inVectorTy.getElementType());
320 in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
// Main loop: 4 output elements per packed register; the inner loop feeds
// two f32 values per packed_trunc, with a ragged tail (asFloatB == nullptr).
323 for (int64_t i = 0; i < numElements; i += 4) {
324 int64_t elemsThisOp = std::min(numElements, i + 4) - i;
325 Value thisResult =
nullptr;
326 for (int64_t j = 0; j < elemsThisOp; j += 2) {
327 Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
328 Value asFloatA =
castToF32(elemA, loc, rewriter);
329 Value asFloatB =
nullptr;
330 if (j + 1 < elemsThisOp) {
331 Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
332 asFloatB =
castToF32(elemB, loc, rewriter);
334 thisResult = amdgpu::PackedTrunc2xFp8Op::create(
335 rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
338 thisResult = vector::ExtractStridedSliceOp::create(
339 rewriter, loc, thisResult, 0, elemsThisOp, 1);
340 result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
344 if (inVectorTy.getRank() != outVecType.getRank()) {
345 result = vector::ShapeCastOp::create(rewriter, loc, outVecType,
result);
// Lowers arith.truncf f32 -> f16 to rocdl.cvt.pkrtz (packed round-toward-zero
// conversion of two f32 values into a 2xf16 vector).
// NOTE(review): lossy extract — guards and some operands are missing.
352LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
353 arith::TruncFOp op, PatternRewriter &rewriter)
const {
354 Type outType = op.getOut().getType();
356 auto outVecType = dyn_cast<VectorType>(outType);
358 if (outVecType.isScalable())
360 outType = outVecType.getElementType();
365 Location loc = op.getLoc();
366 Value in = op.getIn();
368 VectorType truncResType = VectorType::get(2, outElemType);
369 auto inVectorTy = dyn_cast<VectorType>(in.
getType());
// Scalar path: pair the input with a poison second operand, convert, and
// extract lane 0 of the packed 2xf16 result.
373 auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.
getF32Type());
375 ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
376 Value
result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
// Vector path: zero-filled accumulator of the output vector type.
380 int64_t numElements = outVecType.getNumElements();
382 loc, outElemType, rewriter.
getFloatAttr(outElemType, 0.0));
384 rewriter.
createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
// Flatten rank>1 inputs to 1-D before pairwise conversion.
386 if (inVectorTy.getRank() > 1) {
387 inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
388 inVectorTy.getElementType());
389 in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
// Main loop: two elements per packed convert; an odd tail uses poison for
// the second lane and the result is sliced down to one element.
394 for (int64_t i = 0; i < numElements; i += 2) {
395 int64_t elemsThisOp = std::min(numElements, i + 2) - i;
396 Value thisResult =
nullptr;
397 Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
398 Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.
getF32Type());
400 if (elemsThisOp == 2) {
401 elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
405 ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
408 thisResult = vector::ExtractStridedSliceOp::create(
409 rewriter, loc, thisResult, 0, elemsThisOp, 1);
410 result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
414 if (inVectorTy.getRank() != outVecType.getRank()) {
415 result = vector::ShapeCastOp::create(rewriter, loc, outVecType,
result);
// NOTE(review): fragment of a helper that walks a chain of vector.shape_cast
// / vector.broadcast producers back to the original source value (the
// enclosing loop/TypeSwitch head is not visible).
// NOTE(review): the "¤t](" below is mojibake — it decodes the HTML entity
// "&curren;" out of what was originally "[&current](". Fix the encoding at
// the file level; bytes are preserved here.
424 Value current = value;
427 .Case<vector::ShapeCastOp>([¤t](
auto op) {
428 current = op.getSource();
431 .Case<vector::BroadcastOp>([¤t](
auto op) {
432 current = op.getSource();
// Lowers arith.scaling_extf to amdgpu.scaled_ext_packed, tiling the input by
// the per-dimension ratio between the input shape and the scale shape so each
// tile shares one uniform scale value.
// NOTE(review): lossy extract — several guards and declarations are missing.
445ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
446 PatternRewriter &rewriter)
const {
447 Location loc = op.getLoc();
448 constexpr int64_t opOutWidth = 2;
450 Value in = op.getIn();
451 Value scale = op.getScale();
452 Value out = op.getOut();
461 VectorType outVecType = dyn_cast<VectorType>(out.
getType());
462 VectorType scaleVecType = dyn_cast<VectorType>(scale.
getType());
// Scalable vectors are rejected (bail-out body not visible).
464 if (outVecType && outVecType.isScalable())
// Normalize the scale operand to f32 (extend narrower, truncate wider).
468 scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
470 scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
472 scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
474 VectorType extScaleResultType = VectorType::get(opOutWidth, outType);
// Scalar path: broadcast to a 1-element vector and issue one packed ext.
477 Value inCast = vector::BroadcastOp::create(rewriter, loc,
478 VectorType::get(1, inType), in);
480 Value scaleExt = amdgpu::ScaledExtPackedOp::create(
481 rewriter, loc, extScaleResultType, inCast, scale, 0);
// Vector path.
486 VectorType inVecType = cast<VectorType>(in.
getType());
488 VectorType origScaleVecType = dyn_cast<VectorType>(origScale.
getType());
// Pad the (possibly lower-rank) scale shape with trailing 1s to the input
// rank so a per-dimension broadcast ratio can be computed.
490 ArrayRef<int64_t> inShape = inVecType.getShape();
491 SmallVector<int64_t> originalScaleShape;
492 if (origScaleVecType)
493 llvm::append_range(originalScaleShape, origScaleVecType.getShape());
495 originalScaleShape.insert(originalScaleShape.end(),
496 inShape.size() - originalScaleShape.size(), 1);
500 "failed to derive block size from broadcast or splat operation");
502 SmallVector<int64_t> ratio =
503 maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
507 Value zero = arith::ConstantOp::create(rewriter, loc, outType,
510 rewriter.
createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
// Outer loop: one iteration per scale block (tile of shape `ratio`); the
// tile is flattened to 1-D and converted with a single uniform scale.
512 for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
513 SmallVector<int64_t> strides(offsets.size(), 1);
514 Value block = vector::ExtractStridedSliceOp::create(
515 rewriter, loc, in, offsets, ratio, strides);
516 VectorType block1DType = VectorType::get(blockSize, inType);
518 vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
520 vector::ExtractOp::create(rewriter, loc, scale, offsets);
522 VectorType blockResultType = VectorType::get(blockSize, outType);
524 rewriter.
createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
// Inner loops: slice the 1-D block into packed-op-width pieces; partial
// output slices are trimmed before insertion.
526 for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
528 i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
529 Value inSlice = vector::ExtractStridedSliceOp::create(
530 rewriter, loc, block1D, i, inSliceWidth, 1);
532 outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
533 j < inSliceWidth; j += outSliceWidth,
534 outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
536 Value scaleExt = amdgpu::ScaledExtPackedOp::create(
537 rewriter, loc, extScaleResultType, inSlice, uniformScale,
539 if (outSliceWidth < opOutWidth) {
540 scaleExt = vector::ExtractStridedSliceOp::create(
541 rewriter, loc, scaleExt, 0, outSliceWidth, 1);
543 blockResult = vector::InsertStridedSliceOp::create(
544 rewriter, loc, scaleExt, blockResult, i + j, 1);
// Reshape the converted 1-D block back to tile shape and insert it into
// the full result at this tile's offsets.
548 VectorType resultType = VectorType::get(ratio, outType);
550 vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
551 result = vector::InsertStridedSliceOp::create(rewriter, loc, cast,
result,
// Lowers arith.scaling_truncf to amdgpu.packed_scaled_trunc, tiling by the
// input/scale broadcast ratio; mirrors ScalingExtFRewritePattern above but in
// the truncation direction.
// NOTE(review): lossy extract — several guards and declarations are missing.
561ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
562 PatternRewriter &rewriter)
const {
563 Location loc = op.getLoc();
564 constexpr int64_t opInWidth = 2;
566 Value in = op.getIn();
567 Value scale = op.getScale();
568 Value out = op.getOut();
575 VectorType outVecType = dyn_cast<VectorType>(out.
getType());
576 VectorType scaleVecType = dyn_cast<VectorType>(scale.
getType());
// Scalable vectors are rejected (bail-out body not visible).
577 if (outVecType && outVecType.isScalable())
// Normalize the scale operand to f32.
581 scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
583 scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
585 scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);
587 Value zero = arith::ConstantOp::create(rewriter, loc, outType,
590 VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);
// Scalar path: broadcast to 1-element vector, one packed scaled trunc.
593 Type inVecType = VectorType::get(1, inType);
594 Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in);
596 Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
597 rewriter, loc, truncScaleResultType, inCast, scale, 0,
// Vector path.
604 VectorType inVecType = cast<VectorType>(in.
getType());
606 VectorType origScaleVecType = dyn_cast<VectorType>(origScale.
getType());
// Pad the scale shape with trailing 1s to the input rank, then derive the
// per-dimension tile ratio (defaults to all-1s when no ratio is found).
608 ArrayRef<int64_t> inShape = inVecType.getShape();
609 SmallVector<int64_t> scaleShape;
610 if (origScaleVecType)
611 llvm::append_range(scaleShape, origScaleVecType.getShape());
613 scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);
617 "failed to derive block size from broadcast or splat operation");
619 SmallVector<int64_t> ratio =
620 maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));
625 rewriter.
createOrFold<vector::BroadcastOp>(loc, outVecType, zero);
// Outer loop: one iteration per scale tile; flatten to 1-D and truncate the
// whole tile against one uniform scale value.
627 for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
628 SmallVector<int64_t> strides(offsets.size(), 1);
629 Value block = vector::ExtractStridedSliceOp::create(
630 rewriter, loc, in, offsets, ratio, strides);
631 VectorType block1DType = VectorType::get(blockSize, inType);
633 vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
635 vector::ExtractOp::create(rewriter, loc, scale, offsets);
637 VectorType blockResultType = VectorType::get(blockSize, outType);
639 rewriter.
createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);
// Inner loops: emit one packed result register per `opOutWidth` outputs,
// feeding `opInWidth` inputs at a time into the same destination register.
641 for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
642 i < blockSize; i += outSliceWidth,
643 outSliceWidth = std::min(opOutWidth, blockSize - i)) {
646 if (outSliceWidth <= opInWidth) {
647 Value slice = vector::ExtractStridedSliceOp::create(
648 rewriter, loc, block1D, i, outSliceWidth, 1);
650 scaleTrunc = amdgpu::PackedScaledTruncOp::create(
651 rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
654 scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
655 truncScaleResultType, zero);
657 inSliceWidth = std::min(opInWidth, outSliceWidth - j);
658 j < outSliceWidth; j += opInWidth,
659 inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
660 Value slice = vector::ExtractStridedSliceOp::create(
661 rewriter, loc, block1D, i + j, inSliceWidth, 1);
662 scaleTrunc = amdgpu::PackedScaledTruncOp::create(
663 rewriter, loc, truncScaleResultType, slice, uniformScale,
664 j / opInWidth, scaleTrunc);
// Trim partial final registers down to the number of valid lanes.
667 if (outSliceWidth != opOutWidth) {
668 scaleTrunc = vector::ExtractStridedSliceOp::create(
669 rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
671 blockResult = vector::InsertStridedSliceOp::create(
672 rewriter, loc, scaleTrunc, blockResult, i, 1);
// Reshape the 1-D block back to tile shape and insert at the tile offsets.
675 VectorType resultType = VectorType::get(ratio, outType);
677 vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
678 result = vector::InsertStridedSliceOp::create(rewriter, loc, cast,
result,
// NOTE(review): fragment of the pattern-population entry point — the function
// name and leading parameters are not visible in this chunk. Visible logic:
// conditionally register the fp8 truncation, packed-f16 RTZ, and scaled
// ext/trunc rewrite patterns based on the feature flags.
689 bool saturateFP8Truncf,
bool allowPackedF16Rtz,
bool supportsScaledExtTrunc,
692 if (convertFP8Arithmetic) {
695 patterns.add<TruncFToFloat8RewritePattern>(
696 patterns.getContext(), saturateFP8Truncf, chipset, benefit);
698 if (allowPackedF16Rtz)
699 patterns.add<TruncfToFloat16RewritePattern>(
patterns.getContext(), benefit);
// Scaled ext/trunc patterns are presumably added inside this guard (body
// not visible in this chunk).
701 if (supportsScaledExtTrunc) {
// Pass driver: parse the target chipset from the pass option, derive feature
// flags, populate the conversion patterns, and apply them.
// NOTE(review): the body continues past the visible end of this chunk; the
// final pattern application and closing brace are not shown.
707void ArithToAMDGPUConversionPass::runOnOperation() {
// Fail the pass on an unparseable chipset name.
712 if (failed(maybeChipset)) {
713 emitError(UnknownLoc::get(ctx),
"Invalid chipset name: " + chipset);
714 return signalPassFailure();
// Feature detection: scaled ext/trunc is gated on gfx950 specifically.
717 bool convertFP8Arithmetic =
719 bool supportsScaledExtTrunc = *maybeChipset ==
kGfx950;
721 patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
722 supportsScaledExtTrunc, *maybeChipset);
724 return signalPassFailure();