MLIR 22.0.0git
ArithToAMDGPU.cpp
//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

namespace {
// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

struct ArithToAMDGPUConversionPass final
    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
  using impl::ArithToAMDGPUConversionPassBase<
      ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;

  void runOnOperation() override;
};

struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
  using Base::Base;

  Chipset chipset;
  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
                             PatternBenefit benefit)
      : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}

  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const override;
};

struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
  bool saturateFP8 = false;
  Chipset chipset;
  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                               Chipset chipset, PatternBenefit benefit)
      : OpRewritePattern::OpRewritePattern(ctx, benefit),
        saturateFP8(saturateFP8), chipset(chipset) {}

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};

struct TruncfToFloat16RewritePattern final
    : public OpRewritePattern<arith::TruncFOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};

struct ScalingExtFRewritePattern final
    : OpRewritePattern<arith::ScalingExtFOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                PatternRewriter &rewriter) const override;
};

struct ScalingTruncFRewritePattern final
    : OpRewritePattern<arith::ScalingTruncFOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
                                PatternRewriter &rewriter) const override;
};

} // end namespace

static bool isSupportedF8(Type elementType, Chipset chipset) {
  if (chipset == kGfx942)
    return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
  if (hasOcpFp8(chipset))
    return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
  return false;
}

static Value castF32To(Type desType, Value f32, Location loc,
                       PatternRewriter &rewriter) {
  Type elementType = getElementTypeOrSelf(desType);
  if (elementType.isF32())
    return f32;
  if (elementType.getIntOrFloatBitWidth() < 32)
    return arith::TruncFOp::create(rewriter, loc, desType, f32);
  if (elementType.getIntOrFloatBitWidth() > 32)
    return arith::ExtFOp::create(rewriter, loc, desType, f32);
  llvm_unreachable("The only 32-bit float type is f32");
}

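// Illustrative effect of ExtFOnFloat8RewritePattern (a sketch, not verbatim
// compiler output; types and syntax abbreviated):
//   %r = arith.extf %v : vector<4xf8E4M3FN> to vector<4xf32>
// becomes, roughly, two packed extensions, each widening one pair of bytes,
//   %lo = amdgpu.ext_packed_fp8 %v[0] : vector<4xf8E4M3FN> to vector<2xf32>
//   %hi = amdgpu.ext_packed_fp8 %v[1] : vector<4xf8E4M3FN> to vector<2xf32>
// that are recombined with vector.insert_strided_slice, mirroring the
// four-elements-per-iteration loop below.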
LogicalResult
ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
                                            PatternRewriter &rewriter) const {
  Type inType = op.getIn().getType();
  auto inVecType = dyn_cast<VectorType>(inType);
  if (inVecType) {
    if (inVecType.isScalable())
      return failure();
    inType = inVecType.getElementType();
  }
  if (!isSupportedF8(inType, chipset))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType extResType = VectorType::get(2, rewriter.getF32Type());
  if (!inVecType) {
    Value asFloat = amdgpu::ExtPackedFp8Op::create(
        rewriter, loc, rewriter.getF32Type(), in, 0);
    Value result = castF32To(outElemType, asFloat, loc, rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = inVecType.getNumElements();

  Value zero = arith::ConstantOp::create(
      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  VectorType outType = cast<VectorType>(op.getOut().getType());

  if (inVecType.getShape().empty()) {
    Value zerodSplat =
        rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
    Value scalarIn =
        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
    Value scalarExt =
        arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
    Value result = vector::InsertOp::create(rewriter, loc, scalarExt,
                                            zerodSplat, ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }

  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outType.getElementType());
  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

  if (inVecType.getRank() > 1) {
    inVecType = VectorType::get(SmallVector<int64_t>{numElements},
                                inVecType.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
  }

  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i,
                                                          elemsThisOp, 1);
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      if (i + j + 1 < numElements) { // Convert two 8-bit elements.
        Value asFloats = amdgpu::ExtPackedFp8Op::create(
            rewriter, loc, extResType, inSlice, j / 2);
        Type desType = VectorType::get(2, outElemType);
        Value asType = castF32To(desType, asFloats, loc, rewriter);
        result = vector::InsertStridedSliceOp::create(rewriter, loc, asType,
                                                      result, i + j, 1);
      } else { // Convert a single 8-bit element.
        Value asFloat = amdgpu::ExtPackedFp8Op::create(
            rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
        Value asType = castF32To(outElemType, asFloat, loc, rewriter);
        result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
      }
    }
  }

  if (inVecType.getRank() != outType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}

static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
  Type type = value.getType();
  if (type.isF32())
    return value;
  if (type.getIntOrFloatBitWidth() < 32)
    return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
  if (type.getIntOrFloatBitWidth() > 32)
    return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value);
  llvm_unreachable("The only 32-bit float type is f32");
}

// If `in` is a finite value, clamp it between the maximum and minimum values
// of `outElemType` so that subsequent conversion instructions don't
// overflow those out-of-range values to NaN. These semantics are commonly
// used in machine-learning contexts where failure to clamp would lead to
// excessive NaN production.
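// For example (illustrative): when truncating f16 to f8E4M3FN, whose largest
// finite value is 448.0, an input of 1000.0 is clamped to 448.0, while +inf,
// -inf, and NaN are left untouched.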
static Value clampInput(PatternRewriter &rewriter, Location loc,
                        Type outElemType, Value source) {
  Type sourceType = source.getType();
  const llvm::fltSemantics &sourceSem =
      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
  const llvm::fltSemantics &targetSem =
      cast<FloatType>(outElemType).getFloatSemantics();

  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
  bool ignoredLosesInfo = false;
  // We can ignore conversion failures here because this conversion promotes
  // from a smaller type to a larger one - e.g., there can be no loss of
  // precision when casting fp8 to f16.
  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);

  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);

  Value inf = createScalarOrSplatConstant(
      rewriter, loc, sourceType,
      APFloat::getInf(sourceSem, /*Negative=*/false));
  Value negInf = createScalarOrSplatConstant(
      rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, inf);
  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, negInf);
  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::UNO, source, source);
  Value isNonFinite = arith::OrIOp::create(
      rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf),
      isNan);

  Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
  Value clamped =
      arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
  Value res =
      arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
  return res;
}

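// Illustrative effect of TruncFToFloat8RewritePattern (a sketch, not verbatim
// compiler output; types and syntax abbreviated):
//   %r = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FN>
// becomes, roughly, one packed truncation writing two bytes of a 32-bit word,
//   %w = amdgpu.packed_trunc_2xfp8 %a, %b into undef[word 0]
//            : f32 to vector<4xf8E4M3FN>
// followed by a vector.extract_strided_slice keeping the two valid lanes.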
LogicalResult
TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
                                              PatternRewriter &rewriter) const {
  // Only the default rounding mode is supported for now.
  if (op.getRoundingmodeAttr())
    return failure();
  Type outType = op.getOut().getType();
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
  if (inType && inType.getWidth() <= 8 && saturateFP8)
    // Conversion between 8-bit floats is not supported when saturation is
    // enabled.
    return failure();

  if (!isSupportedF8(outType, chipset))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  if (saturateFP8)
    in = clampInput(rewriter, loc, outElemType, in);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());
  VectorType truncResType = VectorType::get(4, outElemType);
  if (!inVectorTy) {
    Value asFloat = castToF32(in, loc, rewriter);
    Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(
        rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
        /*existing=*/nullptr);
    Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }

  int64_t numElements = outVecType.getNumElements();
  Value zero = arith::ConstantOp::create(
      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  if (outVecType.getShape().empty()) {
    Value scalarIn =
        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
    // Recurse to send the 0-D vector case to the scalar case.
    Value scalarTrunc =
        arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
    Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
                                            ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }

  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outVecType.getElementType());
  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
  }

  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value thisResult = nullptr;
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
      Value asFloatA = castToF32(elemA, loc, rewriter);
      Value asFloatB = nullptr;
      if (j + 1 < elemsThisOp) {
        Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
        asFloatB = castToF32(elemB, loc, rewriter);
      }
      thisResult = amdgpu::PackedTrunc2xFp8Op::create(
          rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
    }
    if (elemsThisOp < 4)
      thisResult = vector::ExtractStridedSliceOp::create(
          rewriter, loc, thisResult, 0, elemsThisOp, 1);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
                                                  result, i, 1);
  }

  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}

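// Illustrative effect of TruncfToFloat16RewritePattern (a sketch, not
// verbatim compiler output; syntax abbreviated):
//   %r = arith.truncf %x : f32 to f16
// becomes, roughly,
//   %p = rocdl.cvt.pkrtz %x, %poison : vector<2xf16>
//   %r = vector.extract %p[0] : f16 from vector<2xf16>
// rocdl.cvt.pkrtz lowers to the llvm.amdgcn.cvt.pkrtz intrinsic, a packed
// f32 -> f16 conversion with round-toward-zero semantics.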
LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
    arith::TruncFOp op, PatternRewriter &rewriter) const {
  Type outType = op.getOut().getType();
  Type inputType = getElementTypeOrSelf(op.getIn());
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  if (!(outType.isF16() && inputType.isF32()))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType truncResType = VectorType::get(2, outElemType);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());

  // Handle the case where the input type is not a vector type.
  if (!inVectorTy) {
    auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
    Value asF16s =
        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
    Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = outVecType.getNumElements();
  Value zero = rewriter.createOrFold<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
  }

  // Handle the vector case. We also handle the (uncommon) case where the
  // vector length is odd.
  for (int64_t i = 0; i < numElements; i += 2) {
    int64_t elemsThisOp = std::min(numElements, i + 2) - i;
    Value thisResult = nullptr;
    Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
    Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());

    if (elemsThisOp == 2) {
      elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
    }

    thisResult =
        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
    // Place the truncated result back into the possibly larger vector. If we
    // are operating on a size-2 vector, these operations should fold away.
    thisResult = vector::ExtractStridedSliceOp::create(
        rewriter, loc, thisResult, 0, elemsThisOp, 1);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
                                                  result, i, 1);
  }

  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}

/// Get the broadcasted / splatted value for a chain of ops.
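/// For example (illustrative): given the chain
///   %splat = vector.broadcast %scale : f8E8M0FNU to vector<4xf8E8M0FNU>
///   %cast = vector.shape_cast %splat : vector<4xf8E8M0FNU> to
///           vector<2x2xf8E8M0FNU>
/// walking from %cast looks through the shape_cast, stops at the broadcast,
/// and returns the scalar %scale.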
static Value getOriginalVectorValue(Value value) {
  Value current = value;
  while (Operation *definingOp = current.getDefiningOp()) {
    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
                      .Case<vector::ShapeCastOp>([&current](auto op) {
                        current = op.getSource();
                        return true;
                      })
                      .Case<vector::BroadcastOp>([&current](auto op) {
                        current = op.getSource();
                        return false;
                      })
                      .Default(false);

    if (!skipOp) {
      break;
    }
  }
  return current;
}

LogicalResult
ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                           PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  constexpr int64_t opOutWidth = 2;

  Value in = op.getIn();
  Value scale = op.getScale();
  Value out = op.getOut();

  Type f32 = rewriter.getF32Type();
  Type inType = getElementTypeOrSelf(in);
  Type scaleType = getElementTypeOrSelf(scale);
  Type outType = getElementTypeOrSelf(out);

  int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();

  VectorType outVecType = dyn_cast<VectorType>(out.getType());
  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());

  if (outVecType && outVecType.isScalable())
    return failure();

  Type scaleF32Type =
      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
  if (scaleType.getIntOrFloatBitWidth() < 32)
    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
  else if (scaleType.getIntOrFloatBitWidth() > 32)
    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);

  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);

  if (!outVecType) {
    Value inCast = vector::BroadcastOp::create(rewriter, loc,
                                               VectorType::get(1, inType), in);
    // TODO: replace this with non-packed ScaledExtOp
    Value scaleExt = amdgpu::ScaledExtPackedOp::create(
        rewriter, loc, extScaleResultType, inCast, scale, 0);
    scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
    return success();
  }

  VectorType inVecType = cast<VectorType>(in.getType());
  Value origScale = getOriginalVectorValue(op.getScale());
  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());

  ArrayRef<int64_t> inShape = inVecType.getShape();
  SmallVector<int64_t> originalScaleShape;
  if (origScaleVecType)
    llvm::append_range(originalScaleShape, origScaleVecType.getShape());

  originalScaleShape.insert(originalScaleShape.end(),
                            inShape.size() - originalScaleShape.size(), 1);

  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
  assert(maybeRatio &&
         "failed to derive block size from broadcast or splat operation");

  SmallVector<int64_t> ratio =
      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));

  int64_t blockSize = computeProduct(ratio);
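  // For example (illustrative): if %in has type vector<8xf8E4M3FN> and its
  // scale was splatted from a vector<2xf8E8M0FNU>, the padded scale shape is
  // [2], so ratio = [4] and blockSize = 4: each scale element covers a
  // contiguous block of four input elements.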

  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                         rewriter.getFloatAttr(outType, 0.0));
  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
    SmallVector<int64_t> strides(offsets.size(), 1);
    Value block = vector::ExtractStridedSliceOp::create(
        rewriter, loc, in, offsets, ratio, strides);
    VectorType block1DType = VectorType::get(blockSize, inType);
    Value block1D =
        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
    Value uniformScale =
        vector::ExtractOp::create(rewriter, loc, scale, offsets);

    VectorType blockResultType = VectorType::get(blockSize, outType);
    Value blockResult =
        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
         i < blockSize;
         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
      Value inSlice = vector::ExtractStridedSliceOp::create(
          rewriter, loc, block1D, i, inSliceWidth, 1);
      for (int64_t j = 0,
                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
           j < inSliceWidth; j += outSliceWidth,
                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
        // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
            rewriter, loc, extScaleResultType, inSlice, uniformScale,
            j / opOutWidth);
        if (outSliceWidth < opOutWidth) {
          scaleExt = vector::ExtractStridedSliceOp::create(
              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
        }
        blockResult = vector::InsertStridedSliceOp::create(
            rewriter, loc, scaleExt, blockResult, i + j, 1);
      }
    }

    VectorType resultType = VectorType::get(ratio, outType);
    Value cast =
        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
                                                  offsets, strides);
  }

  rewriter.replaceOp(op, result);

  return success();
}

LogicalResult
ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                             PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  constexpr int64_t opInWidth = 2;

  Value in = op.getIn();
  Value scale = op.getScale();
  Value out = op.getOut();

  Type f32 = rewriter.getF32Type();
  Type inType = getElementTypeOrSelf(in);
  Type scaleType = getElementTypeOrSelf(scale);
  Type outType = getElementTypeOrSelf(out);

  VectorType outVecType = dyn_cast<VectorType>(out.getType());
  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
  if (outVecType && outVecType.isScalable())
    return failure();

  Type scaleF32Type =
      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
  if (scaleType.getIntOrFloatBitWidth() < 32)
    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
  else if (scaleType.getIntOrFloatBitWidth() > 32)
    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);

  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                         rewriter.getFloatAttr(outType, 0.0));
  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);

  if (!outVecType) {
    Type inVecType = VectorType::get(1, inType);
    Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in);
    // TODO: replace this with non-packed ScaledTruncOp
    Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
        rewriter, loc, truncScaleResultType, inCast, scale, 0,
        /*existing=*/nullptr);
    scaleTrunc =
        rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
    return success();
  }

  VectorType inVecType = cast<VectorType>(in.getType());
  Value origScale = getOriginalVectorValue(op.getScale());
  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());

  ArrayRef<int64_t> inShape = inVecType.getShape();
  SmallVector<int64_t> scaleShape;
  if (origScaleVecType)
    llvm::append_range(scaleShape, origScaleVecType.getShape());

  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);

  auto maybeRatio = computeShapeRatio(inShape, scaleShape);
  assert(maybeRatio &&
         "failed to derive block size from broadcast or splat operation");

  SmallVector<int64_t> ratio =
      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));

  int64_t blockSize = computeProduct(ratio);

  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
    SmallVector<int64_t> strides(offsets.size(), 1);
    Value block = vector::ExtractStridedSliceOp::create(
        rewriter, loc, in, offsets, ratio, strides);
    VectorType block1DType = VectorType::get(blockSize, inType);
    Value block1D =
        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
    Value uniformScale =
        vector::ExtractOp::create(rewriter, loc, scale, offsets);

    VectorType blockResultType = VectorType::get(blockSize, outType);
    Value blockResult =
        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
         i < blockSize; i += outSliceWidth,
                 outSliceWidth = std::min(opOutWidth, blockSize - i)) {
      Value scaleTrunc;
      // Case where <= 2 elements are being truncated.
      if (outSliceWidth <= opInWidth) {
        Value slice = vector::ExtractStridedSliceOp::create(
            rewriter, loc, block1D, i, outSliceWidth, 1);
        // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
            /*existing=*/nullptr);
      } else {
        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
                                                 truncScaleResultType, zero);
        for (int64_t j = 0,
                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
             j < outSliceWidth; j += opInWidth,
                     inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
          Value slice = vector::ExtractStridedSliceOp::create(
              rewriter, loc, block1D, i + j, inSliceWidth, 1);
          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
              rewriter, loc, truncScaleResultType, slice, uniformScale,
              j / opInWidth, scaleTrunc);
        }
      }
      if (outSliceWidth != opOutWidth) {
        scaleTrunc = vector::ExtractStridedSliceOp::create(
            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
      }
      blockResult = vector::InsertStridedSliceOp::create(
          rewriter, loc, scaleTrunc, blockResult, i, 1);
    }

    VectorType resultType = VectorType::get(ratio, outType);
    Value cast =
        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
                                                  offsets, strides);
  }

  rewriter.replaceOp(op, result);

  return success();
}

void mlir::arith::populateArithToAMDGPUConversionPatterns(
    RewritePatternSet &patterns, bool convertFP8Arithmetic,
    bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
    Chipset chipset, PatternBenefit benefit) {

  if (convertFP8Arithmetic) {
    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
                                             benefit);
    patterns.add<TruncFToFloat8RewritePattern>(
        patterns.getContext(), saturateFP8Truncf, chipset, benefit);
  }
  if (allowPackedF16Rtz)
    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);

  if (supportsScaledExtTrunc) {
    patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
  }
}
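
// A minimal usage sketch (assumed caller-side code; `ctx`, `op`, and the
// chipset string are placeholders, not part of this file):
//   RewritePatternSet patterns(ctx);
//   FailureOr<amdgpu::Chipset> chipset = amdgpu::Chipset::parse("gfx942");
//   if (succeeded(chipset)) {
//     arith::populateArithToAMDGPUConversionPatterns(
//         patterns, /*convertFP8Arithmetic=*/true,
//         /*saturateFP8Truncf=*/false, /*allowPackedF16Rtz=*/false,
//         /*supportsScaledExtTrunc=*/false, *chipset);
//     (void)applyPatternsGreedily(op, std::move(patterns));
//   }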

void ArithToAMDGPUConversionPass::runOnOperation() {
  Operation *op = getOperation();
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
    return signalPassFailure();
  }

  bool convertFP8Arithmetic =
      *maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
  bool supportsScaledExtTrunc = *maybeChipset == kGfx950;
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
      supportsScaledExtTrunc, *maybeChipset);
  if (failed(applyPatternsGreedily(op, std::move(patterns))))
    return signalPassFailure();
}