//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

namespace {
// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

struct ArithToAMDGPUConversionPass final
    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
  using impl::ArithToAMDGPUConversionPassBase<
      ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;

  void runOnOperation() override;
};

struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
  using Base::Base;

  Chipset chipset;
  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
                             PatternBenefit benefit)
      : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}

  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const override;
};

struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
  bool saturateFP8 = false;
  Chipset chipset;
  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                               Chipset chipset, PatternBenefit benefit)
      : OpRewritePattern::OpRewritePattern(ctx, benefit),
        saturateFP8(saturateFP8), chipset(chipset) {}

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};

struct TruncfToFloat16RewritePattern final
    : public OpRewritePattern<arith::TruncFOp> {

  using Base::Base;

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};

struct ScalingExtFRewritePattern final
    : OpRewritePattern<arith::ScalingExtFOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                PatternRewriter &rewriter) const override;
};

struct ScalingTruncFRewritePattern final
    : OpRewritePattern<arith::ScalingTruncFOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
                                PatternRewriter &rewriter) const override;
};

} // end namespace

static bool isSupportedF8(Type elementType, Chipset chipset) {
  if (chipset == kGfx942)
    return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
  if (hasOcpFp8(chipset))
    return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
  return false;
}

static Value castF32To(Type desType, Value f32, Location loc,
                       PatternRewriter &rewriter) {
  Type elementType = getElementTypeOrSelf(desType);
  if (elementType.isF32())
    return f32;
  if (elementType.getIntOrFloatBitWidth() < 32)
    return arith::TruncFOp::create(rewriter, loc, desType, f32);
  if (elementType.getIntOrFloatBitWidth() > 32)
    return arith::ExtFOp::create(rewriter, loc, desType, f32);
  llvm_unreachable("The only 32-bit float type is f32");
}

LogicalResult
ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
                                            PatternRewriter &rewriter) const {
  Type inType = op.getIn().getType();
  auto inVecType = dyn_cast<VectorType>(inType);
  if (inVecType) {
    if (inVecType.isScalable())
      return failure();
    inType = inVecType.getElementType();
  }
  if (!isSupportedF8(inType, chipset))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType extResType = VectorType::get(2, rewriter.getF32Type());
  if (!inVecType) {
    Value asFloat = amdgpu::ExtPackedFp8Op::create(
        rewriter, loc, rewriter.getF32Type(), in, 0);
    Value result = castF32To(outElemType, asFloat, loc, rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = inVecType.getNumElements();

  Value zero = arith::ConstantOp::create(
      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  VectorType outType = cast<VectorType>(op.getOut().getType());

  if (inVecType.getShape().empty()) {
    Value zerodSplat =
        rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
    Value scalarIn =
        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
    Value scalarExt =
        arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
    Value result = vector::InsertOp::create(rewriter, loc, scalarExt,
                                            zerodSplat, ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }

  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outType.getElementType());
  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

  if (inVecType.getRank() > 1) {
    inVecType = VectorType::get(SmallVector<int64_t>{numElements},
                                inVecType.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
  }

  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i,
                                                          elemsThisOp, 1);
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      if (i + j + 1 < numElements) { // Convert two 8-bit elements
        Value asFloats = amdgpu::ExtPackedFp8Op::create(
            rewriter, loc, extResType, inSlice, j / 2);
        Type desType = VectorType::get(2, outElemType);
        Value asType = castF32To(desType, asFloats, loc, rewriter);
        result = vector::InsertStridedSliceOp::create(rewriter, loc, asType,
                                                      result, i + j, 1);
      } else { // Convert an 8-bit element
        Value asFloat = amdgpu::ExtPackedFp8Op::create(
            rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
        Value asType = castF32To(outElemType, asFloat, loc, rewriter);
        result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
      }
    }
  }

  if (inVecType.getRank() != outType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}
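
// Illustrative effect of the pattern above (schematic IR, not verbatim
// assembly syntax): an extension such as
//
//   %r = arith.extf %v : vector<4xf8E4M3FN> to vector<4xf32>
//
// is rewritten to roughly two amdgpu.ext_packed_fp8 ops, each unpacking one
// pair of fp8 lanes to a vector<2xf32>, whose results are placed with
// vector.insert_strided_slice into a zero-initialized flat vector<4xf32>.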

static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
  Type type = value.getType();
  if (type.isF32())
    return value;
  if (type.getIntOrFloatBitWidth() < 32)
    return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
  if (type.getIntOrFloatBitWidth() > 32)
    return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value);
  llvm_unreachable("The only 32-bit float type is f32");
}

// If `in` is a finite value, clamp it between the maximum and minimum values
// of `outElemType` so that subsequent conversion instructions don't
// overflow those out-of-range values to NaN. These semantics are commonly
// used in machine-learning contexts where failure to clamp would lead to
// excessive NaN production.
static Value clampInput(PatternRewriter &rewriter, Location loc,
                        Type outElemType, Value source) {
  Type sourceType = source.getType();
  const llvm::fltSemantics &sourceSem =
      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
  const llvm::fltSemantics &targetSem =
      cast<FloatType>(outElemType).getFloatSemantics();

  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
  bool ignoredLosesInfo = false;
  // We can ignore conversion failures here because this conversion promotes
  // from a smaller type to a larger one - e.g. there can be no loss of
  // precision when casting fp8 to f16.
  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);

  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);

  Value inf = createScalarOrSplatConstant(
      rewriter, loc, sourceType,
      APFloat::getInf(sourceSem, /*Negative=*/false));
  Value negInf = createScalarOrSplatConstant(
      rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, inf);
  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, negInf);
  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::UNO, source, source);
  Value isNonFinite = arith::OrIOp::create(
      rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf),
      isNan);

  Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
  Value clamped =
      arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
  Value res =
      arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
  return res;
}
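
// Worked example: when saturating a truncation to f8E4M3FN, whose largest
// finite value is 448, a finite input of 1000.0 is clamped to 448.0 and
// -1000.0 to -448.0 before conversion instead of overflowing to NaN, while
// +/-inf and NaN inputs are forwarded unchanged by the select above.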

LogicalResult
TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
                                              PatternRewriter &rewriter) const {
  // Only supporting default rounding mode as of now.
  if (op.getRoundingmodeAttr())
    return failure();
  Type outType = op.getOut().getType();
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
  if (inType && inType.getWidth() <= 8 && saturateFP8)
    // Conversion between 8-bit floats is not supported when saturation is
    // enabled.
    return failure();

  if (!isSupportedF8(outType, chipset))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  if (saturateFP8)
    in = clampInput(rewriter, loc, outElemType, in);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());
  VectorType truncResType = VectorType::get(4, outElemType);
  if (!inVectorTy) {
    Value asFloat = castToF32(in, loc, rewriter);
    Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(
        rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
        /*existing=*/nullptr);
    Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }

  int64_t numElements = outVecType.getNumElements();
  Value zero = arith::ConstantOp::create(
      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  if (outVecType.getShape().empty()) {
    Value scalarIn =
        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
    // Recurse to send the 0-D vector case to the 1-D vector case
    Value scalarTrunc =
        arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
    Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
                                            ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }

  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outVecType.getElementType());
  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
  }

  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value thisResult = nullptr;
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
      Value asFloatA = castToF32(elemA, loc, rewriter);
      Value asFloatB = nullptr;
      if (j + 1 < elemsThisOp) {
        Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
        asFloatB = castToF32(elemB, loc, rewriter);
      }
      thisResult = amdgpu::PackedTrunc2xFp8Op::create(
          rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
    }
    if (elemsThisOp < 4)
      thisResult = vector::ExtractStridedSliceOp::create(
          rewriter, loc, thisResult, 0, elemsThisOp, 1);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
                                                  result, i, 1);
  }

  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}
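
// Illustrative effect of the pattern above (schematic IR): a truncation such
// as
//
//   %r = arith.truncf %v : vector<4xf32> to vector<4xf8E4M3FN>
//
// becomes two amdgpu.packed_trunc_2xfp8 ops, each packing one pair of f32
// values into half of the same vector<4xf8E4M3FN> accumulator, optionally
// preceded by the clamping sequence from clampInput when saturation is
// requested.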

LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
    arith::TruncFOp op, PatternRewriter &rewriter) const {
  Type outType = op.getOut().getType();
  Type inputType = getElementTypeOrSelf(op.getIn());
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  if (!(outType.isF16() && inputType.isF32()))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType truncResType = VectorType::get(2, outElemType);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());

  // Handle the case where the input type is not a vector type.
  if (!inVectorTy) {
    auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
    Value asF16s =
        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
    Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = outVecType.getNumElements();
  Value zero = rewriter.createOrFold<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
  }

  // Handle the vector case. We also handle the (uncommon) case where the
  // vector length is odd.
  for (int64_t i = 0; i < numElements; i += 2) {
    int64_t elemsThisOp = std::min(numElements, i + 2) - i;
    Value thisResult = nullptr;
    Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
    Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());

    if (elemsThisOp == 2) {
      elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
    }

    thisResult =
        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
    // Place back the truncated result into the possibly larger vector. If we
    // are operating on a size 2 vector, these operations should be folded
    // away.
    thisResult = vector::ExtractStridedSliceOp::create(
        rewriter, loc, thisResult, 0, elemsThisOp, 1);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
                                                  result, i, 1);
  }

  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}
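
// Illustrative effect of the pattern above (schematic IR): a truncation such
// as
//
//   %r = arith.truncf %v : vector<3xf32> to vector<3xf16>
//
// becomes two ROCDL::CvtPkRtz ops, the second taking a poison B operand for
// the odd tail lane, each yielding a vector<2xf16> that is trimmed to its
// live lanes and inserted into the result. Note that this lowering rounds
// toward zero rather than the default round-to-nearest-even, which is why it
// is gated behind allowPackedF16Rtz.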

/// Get the broadcasted / splatted value for a chain of ops.
static Value getOriginalVectorValue(Value value) {
  Value current = value;
  while (Operation *definingOp = current.getDefiningOp()) {
    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
                      .Case<vector::ShapeCastOp>([&current](auto op) {
                        current = op.getSource();
                        return true;
                      })
                      .Case<vector::BroadcastOp>([&current](auto op) {
                        current = op.getSource();
                        return false;
                      })
                      .Default([](Operation *) { return false; });

    if (!skipOp) {
      break;
    }
  }
  return current;
}
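
// Example: for a scale operand produced by
//
//   %s1 = vector.broadcast %s0 : f8E8M0FNU to vector<4xf8E8M0FNU>
//   %s2 = vector.shape_cast %s1 : vector<4xf8E8M0FNU> to vector<2x2xf8E8M0FNU>
//
// getOriginalVectorValue(%s2) walks through the shape_cast, then stops at the
// broadcast and returns its source %s0. The scaled-conversion patterns below
// use the shape of this original value to derive the block of input elements
// that shares a single scale.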

LogicalResult
ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                           PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  constexpr int64_t opOutWidth = 2;

  Value in = op.getIn();
  Value scale = op.getScale();
  Value out = op.getOut();

  Type f32 = rewriter.getF32Type();
  Type inType = getElementTypeOrSelf(in);
  Type scaleType = getElementTypeOrSelf(scale);
  Type outType = getElementTypeOrSelf(out);

  int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();

  VectorType outVecType = dyn_cast<VectorType>(out.getType());
  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());

  if (outVecType && outVecType.isScalable())
    return failure();

  Type scaleF32Type =
      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
  if (scaleType.getIntOrFloatBitWidth() < 32)
    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
  else if (scaleType.getIntOrFloatBitWidth() > 32)
    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);

  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);

  if (!outVecType) {
    Value inCast = vector::BroadcastOp::create(rewriter, loc,
                                               VectorType::get(1, inType), in);
    // TODO: replace this with non-packed ScaledExtOp
    Value scaleExt = amdgpu::ScaledExtPackedOp::create(
        rewriter, loc, extScaleResultType, inCast, scale, 0);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
    return success();
  }

  VectorType inVecType = cast<VectorType>(in.getType());
  Value origScale = getOriginalVectorValue(op.getScale());
  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());

  ArrayRef<int64_t> inShape = inVecType.getShape();
  SmallVector<int64_t> originalScaleShape;
  if (origScaleVecType)
    llvm::append_range(originalScaleShape, origScaleVecType.getShape());

  originalScaleShape.insert(originalScaleShape.end(),
                            inShape.size() - originalScaleShape.size(), 1);

  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
  assert(maybeRatio &&
         "failed to derive block size from broadcast or splat operation");

  SmallVector<int64_t> ratio =
      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));

  int64_t blockSize = computeProduct(ratio);

  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                         rewriter.getFloatAttr(outType, 0.0));
  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
    SmallVector<int64_t> strides(offsets.size(), 1);
    Value block = vector::ExtractStridedSliceOp::create(
        rewriter, loc, in, offsets, ratio, strides);
    VectorType block1DType = VectorType::get(blockSize, inType);
    Value block1D =
        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
    Value uniformScale =
        vector::ExtractOp::create(rewriter, loc, scale, offsets);

    VectorType blockResultType = VectorType::get(blockSize, outType);
    Value blockResult =
        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
         i < blockSize;
         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
      Value inSlice = vector::ExtractStridedSliceOp::create(
          rewriter, loc, block1D, i, inSliceWidth, 1);
      for (int64_t j = 0,
                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
           j < inSliceWidth; j += outSliceWidth,
                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
        // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
            rewriter, loc, extScaleResultType, inSlice, uniformScale,
            j / opOutWidth);
        if (outSliceWidth < opOutWidth) {
          scaleExt = vector::ExtractStridedSliceOp::create(
              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
        }
        blockResult = vector::InsertStridedSliceOp::create(
            rewriter, loc, scaleExt, blockResult, i + j, 1);
      }
    }

    VectorType resultType = VectorType::get(ratio, outType);
    Value cast =
        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
                                                  offsets, strides);
  }

  rewriter.replaceOp(op, result);

  return success();
}
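
// Shape of the rewrite above: the input is tiled by the ratio of the input
// shape to the pre-broadcast scale shape, so each tile shares one scale
// element. Every tile is flattened to 1-D, extended at most one 32-bit input
// word at a time with amdgpu.scaled_ext_packed (two results per op), and the
// packed pieces are reassembled into the N-D result via vector.shape_cast and
// vector.insert_strided_slice.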

LogicalResult
ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                             PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  constexpr int64_t opInWidth = 2;

  Value in = op.getIn();
  Value scale = op.getScale();
  Value out = op.getOut();

  Type f32 = rewriter.getF32Type();
  Type inType = getElementTypeOrSelf(in);
  Type scaleType = getElementTypeOrSelf(scale);
  Type outType = getElementTypeOrSelf(out);

  VectorType outVecType = dyn_cast<VectorType>(out.getType());
  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
  if (outVecType && outVecType.isScalable())
    return failure();

  Type scaleF32Type =
      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
  if (scaleType.getIntOrFloatBitWidth() < 32)
    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
  else if (scaleType.getIntOrFloatBitWidth() > 32)
    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);

  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                         rewriter.getFloatAttr(outType, 0.0));
  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);

  if (!outVecType) {
    Type inVecType = VectorType::get(1, inType);
    Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in);
    // TODO: replace this with non-packed ScaledTruncOp
    Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
        rewriter, loc, truncScaleResultType, inCast, scale, 0,
        /*existing=*/nullptr);
    rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
    return success();
  }

  VectorType inVecType = cast<VectorType>(in.getType());
  Value origScale = getOriginalVectorValue(op.getScale());
  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());

  ArrayRef<int64_t> inShape = inVecType.getShape();
  SmallVector<int64_t> scaleShape;
  if (origScaleVecType)
    llvm::append_range(scaleShape, origScaleVecType.getShape());

  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);

  auto maybeRatio = computeShapeRatio(inShape, scaleShape);
  assert(maybeRatio &&
         "failed to derive block size from broadcast or splat operation");

  SmallVector<int64_t> ratio =
      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));

  int64_t blockSize = computeProduct(ratio);

  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
    SmallVector<int64_t> strides(offsets.size(), 1);
    Value block = vector::ExtractStridedSliceOp::create(
        rewriter, loc, in, offsets, ratio, strides);
    VectorType block1DType = VectorType::get(blockSize, inType);
    Value block1D =
        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
    Value uniformScale =
        vector::ExtractOp::create(rewriter, loc, scale, offsets);

    VectorType blockResultType = VectorType::get(blockSize, outType);
    Value blockResult =
        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
         i < blockSize; i += outSliceWidth,
                 outSliceWidth = std::min(opOutWidth, blockSize - i)) {
      Value scaleTrunc;
      // Case where <= 2 elements are being truncated.
      if (outSliceWidth <= opInWidth) {
        Value slice = vector::ExtractStridedSliceOp::create(
            rewriter, loc, block1D, i, outSliceWidth, 1);
        // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
            /*existing=*/nullptr);
      } else {
        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
                                                 truncScaleResultType, zero);
        for (int64_t j = 0,
                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
             j < outSliceWidth; j += opInWidth,
                     inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
          Value slice = vector::ExtractStridedSliceOp::create(
              rewriter, loc, block1D, i + j, inSliceWidth, 1);
          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
              rewriter, loc, truncScaleResultType, slice, uniformScale,
              j / opInWidth, scaleTrunc);
        }
      }
      if (outSliceWidth != opOutWidth) {
        scaleTrunc = vector::ExtractStridedSliceOp::create(
            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
      }
      blockResult = vector::InsertStridedSliceOp::create(
          rewriter, loc, scaleTrunc, blockResult, i, 1);
    }

    VectorType resultType = VectorType::get(ratio, outType);
    Value cast =
        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
                                                  offsets, strides);
  }

  rewriter.replaceOp(op, result);

  return success();
}
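
// Mirror image of the extension pattern above: each tile sharing one scale is
// flattened, truncated two input elements at a time with
// amdgpu.packed_scaled_trunc into a packed 32-bit accumulator, trimmed to the
// live output width, and reassembled into the N-D output.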

void mlir::arith::populateArithToAMDGPUConversionPatterns(
    RewritePatternSet &patterns, bool convertFP8Arithmetic,
    bool saturateFP8Truncf, bool allowPackedF16Rtz, bool supportsScaledExtTrunc,
    Chipset chipset, PatternBenefit benefit) {

  if (convertFP8Arithmetic) {
    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
                                             benefit);
    patterns.add<TruncFToFloat8RewritePattern>(
        patterns.getContext(), saturateFP8Truncf, chipset, benefit);
  }
  if (allowPackedF16Rtz)
    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);

  if (supportsScaledExtTrunc) {
    patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
  }
}
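
// Minimal usage sketch for the populate entry point (hypothetical caller;
// `ctx` and `rootOp` are illustrative placeholders):
//
//   RewritePatternSet patterns(ctx);
//   FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse("gfx942");
//   if (succeeded(maybeChipset))
//     arith::populateArithToAMDGPUConversionPatterns(
//         patterns, /*convertFP8Arithmetic=*/true, /*saturateFP8Truncf=*/false,
//         /*allowPackedF16Rtz=*/false, /*supportsScaledExtTrunc=*/false,
//         *maybeChipset);
//   (void)applyPatternsGreedily(rootOp, std::move(patterns));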

void ArithToAMDGPUConversionPass::runOnOperation() {
  Operation *op = getOperation();
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
    return signalPassFailure();
  }

  bool convertFP8Arithmetic =
      *maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
  bool supportsScaledExtTrunc = *maybeChipset == kGfx950;
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
      supportsScaledExtTrunc, *maybeChipset);
  if (failed(applyPatternsGreedily(op, std::move(patterns))))
    return signalPassFailure();
}
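
// Assuming the usual tablegen'd command-line name for this pass (defined in
// the conversion Passes.td, not in this file), it can be exercised from
// mlir-opt roughly as:
//
//   mlir-opt --convert-arith-to-amdgpu="chipset=gfx942" input.mlir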