//===- ArithToAMDGPU.cpp - Arith to AMDGPU dialect conversion ---------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h"

#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
#include "mlir/Dialect/AMDGPU/Utils/Chipset.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

namespace mlir {
#define GEN_PASS_DEF_ARITHTOAMDGPUCONVERSIONPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;
using namespace mlir::amdgpu;

namespace {
// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

struct ArithToAMDGPUConversionPass final
    : impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
  using impl::ArithToAMDGPUConversionPassBase<
      ArithToAMDGPUConversionPass>::ArithToAMDGPUConversionPassBase;

  void runOnOperation() override;
};

struct ExtFOnFloat8RewritePattern final : OpRewritePattern<arith::ExtFOp> {
  using OpRewritePattern::OpRewritePattern;

  Chipset chipset;
  ExtFOnFloat8RewritePattern(MLIRContext *ctx, Chipset chipset,
                             PatternBenefit benefit)
      : OpRewritePattern::OpRewritePattern(ctx, benefit), chipset(chipset) {}

  LogicalResult matchAndRewrite(arith::ExtFOp op,
                                PatternRewriter &rewriter) const override;
};

struct TruncFToFloat8RewritePattern final : OpRewritePattern<arith::TruncFOp> {
  bool saturateFP8 = false;
  TruncFToFloat8RewritePattern(MLIRContext *ctx, bool saturateFP8,
                               Chipset chipset, PatternBenefit benefit)
      : OpRewritePattern::OpRewritePattern(ctx, benefit),
        saturateFP8(saturateFP8), chipset(chipset) {}
  Chipset chipset;

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};

struct TruncfToFloat16RewritePattern final
    : public OpRewritePattern<arith::TruncFOp> {

  using OpRewritePattern<arith::TruncFOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncFOp op,
                                PatternRewriter &rewriter) const override;
};

struct ScalingExtFRewritePattern final
    : OpRewritePattern<arith::ScalingExtFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
                                PatternRewriter &rewriter) const override;
};

struct ScalingTruncFRewritePattern final
    : OpRewritePattern<arith::ScalingTruncFOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
                                PatternRewriter &rewriter) const override;
};

} // end namespace

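// Returns true if `elementType` is an 8-bit float type with hardware
// conversion support on `chipset`. Illustrative example (assuming an
// MLIRContext *ctx): on gfx942 only the FNUZ variants are supported, so
//   isSupportedF8(Float8E4M3FNUZType::get(ctx), kGfx942) // true
//   isSupportedF8(Float8E4M3FNType::get(ctx), kGfx942)   // false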
static bool isSupportedF8(Type elementType, Chipset chipset) {
  if (chipset == kGfx942)
    return isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(elementType);
  if (hasOcpFp8(chipset))
    return isa<Float8E4M3FNType, Float8E5M2Type>(elementType);
  return false;
}

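// Casts an f32 value (scalar or vector of f32) to `desType`, whose float
// element type may be narrower or wider than f32.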
static Value castF32To(Type desType, Value f32, Location loc,
                       PatternRewriter &rewriter) {
  Type elementType = getElementTypeOrSelf(desType);
  if (elementType.isF32())
    return f32;
  if (elementType.getIntOrFloatBitWidth() < 32)
    return arith::TruncFOp::create(rewriter, loc, desType, f32);
  if (elementType.getIntOrFloatBitWidth() > 32)
    return arith::ExtFOp::create(rewriter, loc, desType, f32);
  llvm_unreachable("The only 32-bit float type is f32");
}

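// Illustrative example of this rewrite (assuming gfx942):
//   %r = arith.extf %v : vector<4xf8E4M3FNUZ> to vector<4xf32>
// becomes a chain of amdgpu.ext_packed_fp8 ops, each extending up to two
// packed 8-bit elements, whose results are reassembled into the output
// vector.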
LogicalResult
ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
                                            PatternRewriter &rewriter) const {
  Type inType = op.getIn().getType();
  auto inVecType = dyn_cast<VectorType>(inType);
  if (inVecType) {
    if (inVecType.isScalable())
      return failure();
    inType = inVecType.getElementType();
  }
  if (!isSupportedF8(inType, chipset))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType extResType = VectorType::get(2, rewriter.getF32Type());
  if (!inVecType) {
    Value asFloat = amdgpu::ExtPackedFp8Op::create(
        rewriter, loc, rewriter.getF32Type(), in, 0);
    Value result = castF32To(outElemType, asFloat, loc, rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = inVecType.getNumElements();

  Value zero = arith::ConstantOp::create(
      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  VectorType outType = cast<VectorType>(op.getOut().getType());

  if (inVecType.getShape().empty()) {
    Value zerodSplat =
        rewriter.createOrFold<vector::BroadcastOp>(loc, outType, zero);
    Value scalarIn =
        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
    Value scalarExt =
        arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
    Value result = vector::InsertOp::create(rewriter, loc, scalarExt,
                                            zerodSplat, ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }

  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outType.getElementType());
  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

  if (inVecType.getRank() > 1) {
    inVecType = VectorType::get(SmallVector<int64_t>{numElements},
                                inVecType.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
  }

  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value inSlice = vector::ExtractStridedSliceOp::create(rewriter, loc, in, i,
                                                          elemsThisOp, 1);
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      if (i + j + 1 < numElements) { // Convert two 8-bit elements
        Value asFloats = amdgpu::ExtPackedFp8Op::create(
            rewriter, loc, extResType, inSlice, j / 2);
        Type desType = VectorType::get(2, outElemType);
        Value asType = castF32To(desType, asFloats, loc, rewriter);
        result = vector::InsertStridedSliceOp::create(rewriter, loc, asType,
                                                      result, i + j, 1);
      } else { // Convert an 8-bit element
        Value asFloat = amdgpu::ExtPackedFp8Op::create(
            rewriter, loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
        Value asType = castF32To(outElemType, asFloat, loc, rewriter);
        result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
      }
    }
  }

  if (inVecType.getRank() != outType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}

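// Casts a scalar float `value` to f32, extending or truncating as needed.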
static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
  Type type = value.getType();
  if (type.isF32())
    return value;
  if (type.getIntOrFloatBitWidth() < 32)
    return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
  if (type.getIntOrFloatBitWidth() > 32)
    return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value);
  llvm_unreachable("The only 32-bit float type is f32");
}

// If `in` is a finite value, clamp it between the maximum and minimum values
// of `outElemType` so that subsequent conversion instructions don't
// overflow those out-of-range values to NaN. These semantics are commonly
// used in machine-learning contexts where failure to clamp would lead to
// excessive NaN production.
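// For example (illustrative): when truncating to f8E4M3FN, whose largest
// finite value is 448.0, a finite f16 input of 65504.0 is clamped to 448.0,
// while +/-infinity and NaN are passed through unchanged.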
static Value clampInput(PatternRewriter &rewriter, Location loc,
                        Type outElemType, Value source) {
  Type sourceType = source.getType();
  const llvm::fltSemantics &sourceSem =
      cast<FloatType>(getElementTypeOrSelf(sourceType)).getFloatSemantics();
  const llvm::fltSemantics &targetSem =
      cast<FloatType>(outElemType).getFloatSemantics();

  APFloat min = APFloat::getLargest(targetSem, /*Negative=*/true);
  APFloat max = APFloat::getLargest(targetSem, /*Negative=*/false);
  bool ignoredLosesInfo = false;
  // We can ignore conversion failures here because this conversion promotes
  // from a smaller type to a larger one - e.g., there can be no loss of
  // precision when casting fp8 to f16.
  (void)min.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);
  (void)max.convert(sourceSem, APFloat::rmNearestTiesToEven, &ignoredLosesInfo);

  Value minCst = createScalarOrSplatConstant(rewriter, loc, sourceType, min);
  Value maxCst = createScalarOrSplatConstant(rewriter, loc, sourceType, max);

  Value inf = createScalarOrSplatConstant(
      rewriter, loc, sourceType,
      APFloat::getInf(sourceSem, /*Negative=*/false));
  Value negInf = createScalarOrSplatConstant(
      rewriter, loc, sourceType, APFloat::getInf(sourceSem, /*Negative=*/true));
  Value isInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, inf);
  Value isNegInf = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::OEQ, source, negInf);
  Value isNan = rewriter.createOrFold<arith::CmpFOp>(
      loc, arith::CmpFPredicate::UNO, source, source);
  Value isNonFinite = arith::OrIOp::create(
      rewriter, loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf),
      isNan);

  Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
  Value clamped =
      arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
  Value res =
      arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
  return res;
}

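// Illustrative example of this rewrite (assuming gfx950):
//   %t = arith.truncf %v : vector<2xf32> to vector<2xf8E4M3FN>
// becomes an amdgpu.packed_trunc_2xfp8 op that packs both results into one
// 32-bit word, followed by a strided-slice extraction back to
// vector<2xf8E4M3FN>.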
LogicalResult
TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
                                              PatternRewriter &rewriter) const {
  // Only supporting default rounding mode as of now.
  if (op.getRoundingmodeAttr())
    return failure();
  Type outType = op.getOut().getType();
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
  if (inType && inType.getWidth() <= 8 && saturateFP8)
    // Conversion between 8-bit floats is not supported when saturation is
    // enabled.
    return failure();

  if (!isSupportedF8(outType, chipset))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  if (saturateFP8)
    in = clampInput(rewriter, loc, outElemType, in);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());
  VectorType truncResType = VectorType::get(4, outElemType);
  if (!inVectorTy) {
    Value asFloat = castToF32(in, loc, rewriter);
    Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(
        rewriter, loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
        /*existing=*/nullptr);
    Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }

  int64_t numElements = outVecType.getNumElements();
  Value zero = arith::ConstantOp::create(
      rewriter, loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  if (outVecType.getShape().empty()) {
    Value scalarIn =
        vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
    // Recurse to send the 0-D vector case through the scalar case above.
    Value scalarTrunc =
        arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
    Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
                                            ArrayRef<int64_t>{});
    rewriter.replaceOp(op, result);
    return success();
  }

  VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
                                      outVecType.getElementType());
  Value result = rewriter.createOrFold<vector::BroadcastOp>(loc, flatTy, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
  }

  for (int64_t i = 0; i < numElements; i += 4) {
    int64_t elemsThisOp = std::min(numElements, i + 4) - i;
    Value thisResult = nullptr;
    for (int64_t j = 0; j < elemsThisOp; j += 2) {
      Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
      Value asFloatA = castToF32(elemA, loc, rewriter);
      Value asFloatB = nullptr;
      if (j + 1 < elemsThisOp) {
        Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
        asFloatB = castToF32(elemB, loc, rewriter);
      }
      thisResult = amdgpu::PackedTrunc2xFp8Op::create(
          rewriter, loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
    }
    if (elemsThisOp < 4)
      thisResult = vector::ExtractStridedSliceOp::create(
          rewriter, loc, thisResult, 0, elemsThisOp, 1);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
                                                  result, i, 1);
  }

  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}

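// Illustrative example of this rewrite:
//   %t = arith.truncf %v : vector<2xf32> to vector<2xf16>
// becomes a single rocdl.cvt.pkrtz, which converts both lanes at once with
// round-towards-zero semantics.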
LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
    arith::TruncFOp op, PatternRewriter &rewriter) const {
  Type outType = op.getOut().getType();
  Type inputType = getElementTypeOrSelf(op.getIn());
  auto outVecType = dyn_cast<VectorType>(outType);
  if (outVecType) {
    if (outVecType.isScalable())
      return failure();
    outType = outVecType.getElementType();
  }
  if (!(outType.isF16() && inputType.isF32()))
    return failure();

  Location loc = op.getLoc();
  Value in = op.getIn();
  Type outElemType = getElementTypeOrSelf(op.getOut().getType());
  VectorType truncResType = VectorType::get(2, outElemType);
  auto inVectorTy = dyn_cast<VectorType>(in.getType());

  // Handle the case where the input type is not a vector type.
  if (!inVectorTy) {
    auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
    Value asF16s =
        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
    Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
    rewriter.replaceOp(op, result);
    return success();
  }
  int64_t numElements = outVecType.getNumElements();
  Value zero = rewriter.createOrFold<arith::ConstantOp>(
      loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  if (inVectorTy.getRank() > 1) {
    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
                                 inVectorTy.getElementType());
    in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
  }

  // Handle the vector case. We also handle the (uncommon) case where the
  // vector length is odd.
  for (int64_t i = 0; i < numElements; i += 2) {
    int64_t elemsThisOp = std::min(numElements, i + 2) - i;
    Value thisResult = nullptr;
    Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
    Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());

    if (elemsThisOp == 2) {
      elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
    }

    thisResult =
        ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
    // Place the truncated result back into the possibly larger vector. If we
    // are operating on a size-2 vector, these operations should be folded
    // away.
    thisResult = vector::ExtractStridedSliceOp::create(
        rewriter, loc, thisResult, 0, elemsThisOp, 1);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
                                                  result, i, 1);
  }

  if (inVectorTy.getRank() != outVecType.getRank()) {
    result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
  }

  rewriter.replaceOp(op, result);
  return success();
}

/// Get the broadcasted / splatted value for a chain of ops.
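/// For example (illustrative), given:
///   %b = vector.broadcast %s : vector<4xf8E8M0FNU> to vector<8x4xf8E8M0FNU>
///   %c = vector.shape_cast %b : vector<8x4xf8E8M0FNU> to vector<32xf8E8M0FNU>
/// this walks through the shape_cast and broadcast and returns %s.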
static Value getOriginalVectorValue(Value value) {
  Value current = value;
  while (Operation *definingOp = current.getDefiningOp()) {
    bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
                      .Case<vector::ShapeCastOp>([&current](auto op) {
                        current = op.getSource();
                        return true;
                      })
                      .Case<vector::BroadcastOp>([&current](auto op) {
                        current = op.getSource();
                        return false;
                      })
                      .Case<vector::SplatOp>([&current](auto op) {
                        current = op.getInput();
                        return false;
                      })
                      .Default([](Operation *) { return false; });

    if (!skipOp) {
      break;
    }
  }
  return current;
}

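// The lowering below tiles the input into blocks that share a single scale
// value (recovered from the broadcast/splat feeding the scale operand) and
// emits amdgpu.scaled_ext_packed ops over sub-slices of each block, producing
// at most two extended results per op.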
LogicalResult
ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
                                           PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  constexpr int64_t opOutWidth = 2;

  Value in = op.getIn();
  Value scale = op.getScale();
  Value out = op.getOut();

  Type f32 = rewriter.getF32Type();
  Type inType = getElementTypeOrSelf(in);
  Type scaleType = getElementTypeOrSelf(scale);
  Type outType = getElementTypeOrSelf(out);

  int64_t opInWidth = 32 / inType.getIntOrFloatBitWidth();

  VectorType outVecType = dyn_cast<VectorType>(out.getType());
  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());

  if (outVecType && outVecType.isScalable())
    return failure();

  Type scaleF32Type =
      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
  if (scaleType.getIntOrFloatBitWidth() < 32)
    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
  else if (scaleType.getIntOrFloatBitWidth() > 32)
    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);

  VectorType extScaleResultType = VectorType::get(opOutWidth, outType);

  if (!outVecType) {
    Value inCast = vector::BroadcastOp::create(rewriter, loc,
                                               VectorType::get(1, inType), in);
    // TODO: replace this with non-packed ScaledExtOp
    Value scaleExt = amdgpu::ScaledExtPackedOp::create(
        rewriter, loc, extScaleResultType, inCast, scale, 0);
    scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
    return success();
  }

  VectorType inVecType = cast<VectorType>(in.getType());
  Value origScale = getOriginalVectorValue(op.getScale());
  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());

  ArrayRef<int64_t> inShape = inVecType.getShape();
  SmallVector<int64_t> originalScaleShape;
  if (origScaleVecType)
    llvm::append_range(originalScaleShape, origScaleVecType.getShape());

  originalScaleShape.insert(originalScaleShape.end(),
                            inShape.size() - originalScaleShape.size(), 1);

  auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
  assert(maybeRatio &&
         "failed to derive block size from broadcast or splat operation");

  SmallVector<int64_t> ratio =
      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));

  int64_t blockSize = computeProduct(ratio);

  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                         rewriter.getFloatAttr(outType, 0.0));
  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
    SmallVector<int64_t> strides(offsets.size(), 1);
    Value block = vector::ExtractStridedSliceOp::create(
        rewriter, loc, in, offsets, ratio, strides);
    VectorType block1DType = VectorType::get(blockSize, inType);
    Value block1D =
        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
    Value uniformScale =
        vector::ExtractOp::create(rewriter, loc, scale, offsets);

    VectorType blockResultType = VectorType::get(blockSize, outType);
    Value blockResult =
        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

    for (int64_t i = 0, inSliceWidth = std::min(opInWidth, blockSize - i);
         i < blockSize;
         i += inSliceWidth, inSliceWidth = std::min(opInWidth, blockSize - i)) {
      Value inSlice = vector::ExtractStridedSliceOp::create(
          rewriter, loc, block1D, i, inSliceWidth, 1);
      for (int64_t j = 0,
                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j);
           j < inSliceWidth; j += outSliceWidth,
                   outSliceWidth = std::min(opOutWidth, inSliceWidth - j)) {
        // TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
        Value scaleExt = amdgpu::ScaledExtPackedOp::create(
            rewriter, loc, extScaleResultType, inSlice, uniformScale,
            j / opOutWidth);
        if (outSliceWidth < opOutWidth) {
          scaleExt = vector::ExtractStridedSliceOp::create(
              rewriter, loc, scaleExt, 0, outSliceWidth, 1);
        }
        blockResult = vector::InsertStridedSliceOp::create(
            rewriter, loc, scaleExt, blockResult, i + j, 1);
      }
    }

    VectorType resultType = VectorType::get(ratio, outType);
    Value cast =
        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
                                                  offsets, strides);
  }

  rewriter.replaceOp(op, result);

  return success();
}

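// Mirrors the scaling_extf lowering above: tile the input into blocks that
// share one scale value, then emit amdgpu.packed_scaled_trunc ops that pack
// up to two truncated results into a 32-bit word at each step.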
LogicalResult
ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
                                             PatternRewriter &rewriter) const {
  Location loc = op.getLoc();
  constexpr int64_t opInWidth = 2;

  Value in = op.getIn();
  Value scale = op.getScale();
  Value out = op.getOut();

  Type f32 = rewriter.getF32Type();
  Type inType = getElementTypeOrSelf(in);
  Type scaleType = getElementTypeOrSelf(scale);
  Type outType = getElementTypeOrSelf(out);

  VectorType outVecType = dyn_cast<VectorType>(out.getType());
  VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());
  if (outVecType && outVecType.isScalable())
    return failure();

  Type scaleF32Type =
      scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
  if (scaleType.getIntOrFloatBitWidth() < 32)
    scale = arith::ExtFOp::create(rewriter, loc, scaleF32Type, scale);
  else if (scaleType.getIntOrFloatBitWidth() > 32)
    scale = arith::TruncFOp::create(rewriter, loc, scaleF32Type, scale);

  Value zero = arith::ConstantOp::create(rewriter, loc, outType,
                                         rewriter.getFloatAttr(outType, 0.0));
  int64_t opOutWidth = 32 / outType.getIntOrFloatBitWidth();
  VectorType truncScaleResultType = VectorType::get(opOutWidth, outType);

  if (!outVecType) {
    Type inVecType = VectorType::get(1, inType);
    Value inCast = vector::BroadcastOp::create(rewriter, loc, inVecType, in);
    // TODO: replace this with non-packed ScaledTruncOp
    Value scaleTrunc = amdgpu::PackedScaledTruncOp::create(
        rewriter, loc, truncScaleResultType, inCast, scale, 0,
        /*existing=*/nullptr);
    scaleTrunc =
        rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
    return success();
  }

  VectorType inVecType = cast<VectorType>(in.getType());
  Value origScale = getOriginalVectorValue(op.getScale());
  VectorType origScaleVecType = dyn_cast<VectorType>(origScale.getType());

  ArrayRef<int64_t> inShape = inVecType.getShape();
  SmallVector<int64_t> scaleShape;
  if (origScaleVecType)
    llvm::append_range(scaleShape, origScaleVecType.getShape());

  scaleShape.insert(scaleShape.end(), inShape.size() - scaleShape.size(), 1);

  auto maybeRatio = computeShapeRatio(inShape, scaleShape);
  assert(maybeRatio &&
         "failed to derive block size from broadcast or splat operation");

  SmallVector<int64_t> ratio =
      maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));

  int64_t blockSize = computeProduct(ratio);

  Value result =
      rewriter.createOrFold<vector::BroadcastOp>(loc, outVecType, zero);

  for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
    SmallVector<int64_t> strides(offsets.size(), 1);
    Value block = vector::ExtractStridedSliceOp::create(
        rewriter, loc, in, offsets, ratio, strides);
    VectorType block1DType = VectorType::get(blockSize, inType);
    Value block1D =
        vector::ShapeCastOp::create(rewriter, loc, block1DType, block);
    Value uniformScale =
        vector::ExtractOp::create(rewriter, loc, scale, offsets);

    VectorType blockResultType = VectorType::get(blockSize, outType);
    Value blockResult =
        rewriter.createOrFold<vector::BroadcastOp>(loc, blockResultType, zero);

    for (int64_t i = 0, outSliceWidth = std::min(opOutWidth, blockSize - i);
         i < blockSize; i += outSliceWidth,
                 outSliceWidth = std::min(opOutWidth, blockSize - i)) {
      Value scaleTrunc;
      // Case where <= 2 elements are being truncated.
      if (outSliceWidth <= opInWidth) {
        Value slice = vector::ExtractStridedSliceOp::create(
            rewriter, loc, block1D, i, outSliceWidth, 1);
        // TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
        scaleTrunc = amdgpu::PackedScaledTruncOp::create(
            rewriter, loc, truncScaleResultType, slice, uniformScale, 0,
            /*existing=*/nullptr);
      } else {
        scaleTrunc = vector::BroadcastOp::create(rewriter, loc,
                                                 truncScaleResultType, zero);
        for (int64_t j = 0,
                     inSliceWidth = std::min(opInWidth, outSliceWidth - j);
             j < outSliceWidth; j += opInWidth,
                     inSliceWidth = std::min(opInWidth, outSliceWidth - j)) {
          Value slice = vector::ExtractStridedSliceOp::create(
              rewriter, loc, block1D, i + j, inSliceWidth, 1);
          scaleTrunc = amdgpu::PackedScaledTruncOp::create(
              rewriter, loc, truncScaleResultType, slice, uniformScale,
              j / opInWidth, scaleTrunc);
        }
      }
      if (outSliceWidth != opOutWidth) {
        scaleTrunc = vector::ExtractStridedSliceOp::create(
            rewriter, loc, scaleTrunc, 0, outSliceWidth, 1);
      }
      blockResult = vector::InsertStridedSliceOp::create(
          rewriter, loc, scaleTrunc, blockResult, i, 1);
    }

    VectorType resultType = VectorType::get(ratio, outType);
    Value cast =
        vector::ShapeCastOp::create(rewriter, loc, resultType, blockResult);
    result = vector::InsertStridedSliceOp::create(rewriter, loc, cast, result,
                                                  offsets, strides);
  }

  rewriter.replaceOp(op, result);

  return success();
}

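// Public entry point declared in
// mlir/Conversion/ArithToAMDGPU/ArithToAMDGPU.h; clients can add these
// patterns to any RewritePatternSet, as the pass below does.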
void mlir::arith::populateArithToAMDGPUConversionPatterns(
    RewritePatternSet &patterns, bool convertFP8Arithmetic,
    bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset,
    PatternBenefit benefit) {

  if (convertFP8Arithmetic) {
    patterns.add<ExtFOnFloat8RewritePattern>(patterns.getContext(), chipset,
                                             benefit);
    patterns.add<TruncFToFloat8RewritePattern>(
        patterns.getContext(), saturateFP8Truncf, chipset, benefit);
  }
  if (allowPackedF16Rtz)
    patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext(), benefit);

  if (chipset >= kGfx950) {
    patterns.add<ScalingExtFRewritePattern>(patterns.getContext(), benefit);
    patterns.add<ScalingTruncFRewritePattern>(patterns.getContext(), benefit);
  }
}

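// Typical command-line usage (illustrative; pass name and options as
// registered in the conversion Passes.td):
//   mlir-opt -convert-arith-to-amdgpu="chipset=gfx942" input.mlir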
void ArithToAMDGPUConversionPass::runOnOperation() {
  Operation *op = getOperation();
  MLIRContext *ctx = &getContext();
  RewritePatternSet patterns(ctx);
  FailureOr<amdgpu::Chipset> maybeChipset = amdgpu::Chipset::parse(chipset);
  if (failed(maybeChipset)) {
    emitError(UnknownLoc::get(ctx), "Invalid chipset name: " + chipset);
    return signalPassFailure();
  }

  bool convertFP8Arithmetic =
      *maybeChipset == kGfx942 || hasOcpFp8(*maybeChipset);
  arith::populateArithToAMDGPUConversionPatterns(
      patterns, convertFP8Arithmetic, saturateFP8Truncf, allowPackedF16Rtz,
      *maybeChipset);
  if (failed(applyPatternsGreedily(op, std::move(patterns))))
    return signalPassFailure();
}