MLIR 23.0.0git
ExpandOps.cpp
Go to the documentation of this file.
//===- ExpandOps.cpp - Code to expand various math operations. ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements expansion of various math operations.
//
//===----------------------------------------------------------------------===//
12
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/Matchers.h"
20
21using namespace mlir;
22
23namespace mlir::math {
24#define GEN_PASS_DEF_MATHEXPANDOPSPASS
25#include "mlir/Dialect/Math/Transforms/Passes.h.inc"
26} // namespace mlir::math
27
28/// Create a float constant.
29static Value createFloatConst(Location loc, Type type, APFloat value,
30 OpBuilder &b) {
31 bool losesInfo = false;
32 auto eltType = getElementTypeOrSelf(type);
33 // Convert double to the given `FloatType` with round-to-nearest-ties-to-even.
34 value.convert(cast<FloatType>(eltType).getFloatSemantics(),
35 APFloat::rmNearestTiesToEven, &losesInfo);
36 auto attr = b.getFloatAttr(eltType, value);
37 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
38 return arith::ConstantOp::create(b, loc,
39 DenseElementsAttr::get(shapedTy, attr));
40 }
42 return arith::ConstantOp::create(b, loc, attr);
45static Value createFloatConst(Location loc, Type type, double value,
46 OpBuilder &b) {
47 return createFloatConst(loc, type, APFloat(value), b);
49
50/// Create an integer constant.
51static Value createIntConst(Location loc, Type type, int64_t value,
52 OpBuilder &b) {
53 auto attr = b.getIntegerAttr(getElementTypeOrSelf(type), value);
54 if (auto shapedTy = dyn_cast<ShapedType>(type)) {
55 return arith::ConstantOp::create(b, loc,
56 DenseElementsAttr::get(shapedTy, attr));
57 }
58
59 return arith::ConstantOp::create(b, loc, attr);
60}
61
63 Type opType = operand.getType();
64 Type i64Ty = b.getI64Type();
65 if (auto shapedTy = dyn_cast<ShapedType>(opType))
66 i64Ty = shapedTy.clone(i64Ty);
67 Value fixedConvert = arith::FPToSIOp::create(b, i64Ty, operand);
68 Value fpFixedConvert = arith::SIToFPOp::create(b, opType, fixedConvert);
69 // The truncation does not preserve the sign when the truncated
70 // value is -0. So here the sign is copied again.
71 return math::CopySignOp::create(b, fpFixedConvert, operand);
73
74// sinhf(float x) -> (exp(x) - exp(-x)) / 2
75static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter) {
76 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
77 Value operand = op.getOperand();
78 Type opType = operand.getType();
79
80 Value exp = math::ExpOp::create(b, operand);
81 Value neg = arith::NegFOp::create(b, operand);
82 Value nexp = math::ExpOp::create(b, neg);
83 Value sub = arith::SubFOp::create(b, exp, nexp);
84 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
85 Value res = arith::MulFOp::create(b, sub, half);
86 rewriter.replaceOp(op, res);
87 return success();
89
90// coshf(float x) -> (exp(x) + exp(-x)) / 2
91static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter) {
92 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
93 Value operand = op.getOperand();
94 Type opType = operand.getType();
95
96 Value exp = math::ExpOp::create(b, operand);
97 Value neg = arith::NegFOp::create(b, operand);
98 Value nexp = math::ExpOp::create(b, neg);
99 Value add = arith::AddFOp::create(b, exp, nexp);
100 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
101 Value res = arith::MulFOp::create(b, add, half);
102 rewriter.replaceOp(op, res);
103 return success();
104}
105
106/// Expands tanh op into
107/// 1-exp^{-2x} / 1+exp^{-2x}
108/// To avoid overflow we exploit the reflection symmetry `tanh(-x) = -tanh(x)`.
109/// We compute a "signs" value which is -1 if input is negative and +1 if input
110/// is positive. Then multiply the input by this value, guaranteeing that the
111/// result is positive, which also guarantees `exp^{-2x * sign(x)}` is in (0,
112/// 1]. Expand the computation on the input `x * sign(x)`, then multiply the
113/// result by `sign(x)` to retain sign of the real result.
114static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter) {
115 auto floatType = op.getOperand().getType();
116 Location loc = op.getLoc();
117 Value zero = createFloatConst(loc, floatType, 0.0, rewriter);
118 Value one = createFloatConst(loc, floatType, 1.0, rewriter);
119 Value negTwo = createFloatConst(loc, floatType, -2.0, rewriter);
120
121 // Compute sign(x) = cast<float_type>(x < 0) * (-2) + 1
122 Value isNegative = arith::CmpFOp::create(
123 rewriter, loc, arith::CmpFPredicate::OLT, op.getOperand(), zero);
124 Value isNegativeFloat =
125 arith::UIToFPOp::create(rewriter, loc, floatType, isNegative);
126 Value isNegativeTimesNegTwo =
127 arith::MulFOp::create(rewriter, loc, isNegativeFloat, negTwo);
128 Value sign = arith::AddFOp::create(rewriter, loc, isNegativeTimesNegTwo, one);
129
130 // Normalize input to positive value: y = sign(x) * x
131 Value positiveX = arith::MulFOp::create(rewriter, loc, sign, op.getOperand());
132
133 // Decompose on normalized input
134 Value negDoubledX = arith::MulFOp::create(rewriter, loc, negTwo, positiveX);
135 Value exp2x = math::ExpOp::create(rewriter, loc, negDoubledX);
136 Value dividend = arith::SubFOp::create(rewriter, loc, one, exp2x);
137 Value divisor = arith::AddFOp::create(rewriter, loc, one, exp2x);
138 Value positiveRes = arith::DivFOp::create(rewriter, loc, dividend, divisor);
139
140 // Multiply result by sign(x) to retain signs from negative inputs
141 rewriter.replaceOpWithNewOp<arith::MulFOp>(op, sign, positiveRes);
142
143 return success();
144}
145
146// Converts math.tan to math.sin, math.cos, and arith.divf.
147static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter) {
148 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
149 Value operand = op.getOperand();
150 Type type = operand.getType();
151 Value sin = math::SinOp::create(b, type, operand);
152 Value cos = math::CosOp::create(b, type, operand);
153 Value div = arith::DivFOp::create(b, type, sin, cos);
154 rewriter.replaceOp(op, div);
155 return success();
156}
157
158// asinh(float x) -> log(x + sqrt(x**2 + 1))
159static LogicalResult convertAsinhOp(math::AsinhOp op,
160 PatternRewriter &rewriter) {
161 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
162 Value operand = op.getOperand();
163 Type opType = operand.getType();
164
165 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
166 Value fma = math::FmaOp::create(b, operand, operand, one);
167 Value sqrt = math::SqrtOp::create(b, fma);
168 Value add = arith::AddFOp::create(b, operand, sqrt);
169 Value res = math::LogOp::create(b, add);
170 rewriter.replaceOp(op, res);
171 return success();
172}
173
174// acosh(float x) -> log(x + sqrt(x**2 - 1))
175static LogicalResult convertAcoshOp(math::AcoshOp op,
176 PatternRewriter &rewriter) {
177 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
178 Value operand = op.getOperand();
179 Type opType = operand.getType();
180
181 Value negOne = createFloatConst(op->getLoc(), opType, -1.0, rewriter);
182 Value fma = math::FmaOp::create(b, operand, operand, negOne);
183 Value sqrt = math::SqrtOp::create(b, fma);
184 Value add = arith::AddFOp::create(b, operand, sqrt);
185 Value res = math::LogOp::create(b, add);
186 rewriter.replaceOp(op, res);
187 return success();
188}
189
190// atanh(float x) -> log((1 + x) / (1 - x)) / 2
191static LogicalResult convertAtanhOp(math::AtanhOp op,
192 PatternRewriter &rewriter) {
193 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
194 Value operand = op.getOperand();
195 Type opType = operand.getType();
196
197 Value one = createFloatConst(op->getLoc(), opType, 1.0, rewriter);
198 Value add = arith::AddFOp::create(b, operand, one);
199 Value neg = arith::NegFOp::create(b, operand);
200 Value sub = arith::AddFOp::create(b, neg, one);
201 Value div = arith::DivFOp::create(b, add, sub);
202 Value log = math::LogOp::create(b, div);
203 Value half = createFloatConst(op->getLoc(), opType, 0.5, rewriter);
204 Value res = arith::MulFOp::create(b, log, half);
205 rewriter.replaceOp(op, res);
206 return success();
207}
208
209static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter) {
210 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
211 Value operandA = op.getOperand(0);
212 Value operandB = op.getOperand(1);
213 Value operandC = op.getOperand(2);
214 Type type = op.getType();
215 Value mult = arith::MulFOp::create(b, type, operandA, operandB);
216 Value add = arith::AddFOp::create(b, type, mult, operandC);
217 rewriter.replaceOp(op, add);
218 return success();
219}
220
221// Converts a ceilf() function to the following:
222// ceilf(float x) ->
223// y = (float)(int) x
224// if (x > y) then incr = 1 else incr = 0
225// y = y + incr <= replace this op with the ceilf op.
226static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter) {
227 // Creating constants assumes the static shaped type.
228 auto shapedType = dyn_cast<ShapedType>(op.getType());
229 if (shapedType && !shapedType.hasStaticShape())
230 return failure();
231
232 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
233 Value operand = op.getOperand();
234 Type opType = operand.getType();
235 Type operandETy = getElementTypeOrSelf(opType);
236 FloatType floatTy = llvm::dyn_cast<FloatType>(operandETy);
237 const llvm::fltSemantics &semantics = floatTy.getFloatSemantics();
238
239 unsigned bitWidth = floatTy.getWidth();
240 unsigned mantissaWidth = floatTy.getFPMantissaWidth() - 1;
241 const int bias = (&semantics == &APFloat::Float8E8M0FNU())
242 ? -semantics.minExponent
243 : -(semantics.minExponent - 1);
244 bool hasNegativeZeroNaNEncoding =
245 (semantics.nanEncoding == llvm::fltNanEncoding::NegativeZero);
246
247 Type iTy = rewriter.getIntegerType(bitWidth);
248 if (auto shapedTy = dyn_cast<ShapedType>(opType))
249 iTy = shapedTy.clone(iTy);
250
251 // For IEEE-like floating-point formats with an unbiased exponent ≥
252 // `mantissaWidth` falls into one of these categories:
253 // - a large finite value (|x| ≥ 2^mantissaWidth), where all representable
254 // numbers are already integral, or
255 // - a special value (NaN or ±Inf), which also satisfies this exponent
256 // condition.
257 // For all such cases, `ceilf(x)` is defined to return `x` directly.
258 Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
259 Value cMask = createIntConst(
260 op->getLoc(), iTy, static_cast<int64_t>((1ull << (bitWidth - 1)) - 1), b);
261 Value unsignedBits = arith::AndIOp::create(b, operandBitcast, cMask);
262 Value cThreshold = createIntConst(
263 op->getLoc(), iTy,
264 static_cast<int64_t>((uint64_t(bias + mantissaWidth)) << mantissaWidth),
265 b);
266 Value isLargeExp = arith::CmpIOp::create(b, arith::CmpIPredicate::uge,
267 unsignedBits, cThreshold);
268 Value isSpecialValOrLargeVal = isLargeExp;
269
270 // In FNUZ-suffixed floating point, NaN is represented by a sign bit of 1 and
271 // all 0s in the exponent and mantissa, therefore requires an explicit check.
272 if (hasNegativeZeroNaNEncoding) {
273 Value cNegZeroBits = createIntConst(
274 op->getLoc(), iTy, static_cast<int64_t>(1ull << (bitWidth - 1)), b);
275 Value isNegZeroEncoding = arith::CmpIOp::create(
276 b, arith::CmpIPredicate::eq, operandBitcast, cNegZeroBits);
277 isSpecialValOrLargeVal =
278 arith::OrIOp::create(b, isLargeExp, isNegZeroEncoding);
279 }
280
281 Value fpFixedConvert = createTruncatedFPValue(operand, b);
282
283 // Creating constants for later use.
284 Value zero = createFloatConst(op->getLoc(), opType, 0.00, rewriter);
285 Value one = createFloatConst(op->getLoc(), opType, 1.00, rewriter);
286
287 Value gtCheck = arith::CmpFOp::create(b, arith::CmpFPredicate::OGT, operand,
288 fpFixedConvert);
289 Value incrValue =
290 arith::SelectOp::create(b, op->getLoc(), gtCheck, one, zero);
291
292 Value add = arith::AddFOp::create(b, opType, fpFixedConvert, incrValue);
293 Value ret = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand, add);
294 rewriter.replaceOp(op, ret);
295 return success();
296}
297
298// Convert `math.fpowi` to a series of `arith.mulf` operations.
299// If the power is negative, we divide one by the result.
300// If both the base and power are zero, the result is 1.
301// In the case of non constant power, we convert the operation to `math.powf`.
302static LogicalResult convertFPowIOp(math::FPowIOp op,
303 PatternRewriter &rewriter) {
304 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
305 Value base = op.getOperand(0);
306 Value power = op.getOperand(1);
307 Type baseType = base.getType();
308
309 auto convertFPowItoPowf = [&]() -> LogicalResult {
310 Value castPowerToFp =
311 arith::SIToFPOp::create(rewriter, op.getLoc(), baseType, power);
312 Value res = math::PowFOp::create(rewriter, op.getLoc(), baseType, base,
313 castPowerToFp);
314 rewriter.replaceOp(op, res);
315 return success();
316 };
317
318 Attribute cstAttr;
319 if (!matchPattern(power, m_Constant(&cstAttr)))
320 return convertFPowItoPowf();
321
322 APInt value;
323 if (!matchPattern(cstAttr, m_ConstantInt(&value)))
324 return convertFPowItoPowf();
325
326 int64_t powerInt = value.getSExtValue();
327 bool isNegative = powerInt < 0;
328 int64_t absPower = std::abs(powerInt);
329 Value one = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
330 Value res = createFloatConst(op->getLoc(), baseType, 1.00, rewriter);
331
332 while (absPower > 0) {
333 if (absPower & 1)
334 res = arith::MulFOp::create(b, baseType, base, res);
335 absPower >>= 1;
336 base = arith::MulFOp::create(b, baseType, base, base);
337 }
338
339 // Make sure not to introduce UB in case of negative power.
340 if (isNegative) {
341 auto &sem = dyn_cast<mlir::FloatType>(getElementTypeOrSelf(baseType))
342 .getFloatSemantics();
343 Value zero =
344 createFloatConst(op->getLoc(), baseType,
345 APFloat::getZero(sem, /*Negative=*/false), rewriter);
346 Value negZero =
347 createFloatConst(op->getLoc(), baseType,
348 APFloat::getZero(sem, /*Negative=*/true), rewriter);
349 Value posInfinity =
350 createFloatConst(op->getLoc(), baseType,
351 APFloat::getInf(sem, /*Negative=*/false), rewriter);
352 Value negInfinity =
353 createFloatConst(op->getLoc(), baseType,
354 APFloat::getInf(sem, /*Negative=*/true), rewriter);
355 Value zeroEqCheck =
356 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, zero);
357 Value negZeroEqCheck =
358 arith::CmpFOp::create(b, arith::CmpFPredicate::OEQ, res, negZero);
359 res = arith::DivFOp::create(b, baseType, one, res);
360 res =
361 arith::SelectOp::create(b, op->getLoc(), zeroEqCheck, posInfinity, res);
362 res = arith::SelectOp::create(b, op->getLoc(), negZeroEqCheck, negInfinity,
363 res);
364 }
365
366 rewriter.replaceOp(op, res);
367 return success();
368}
369
370// Converts Powf(float a, float b) (meaning a^b) to exp^(b * ln(a))
371// Some special cases where b is constant are handled separately:
372// when b == 0, or |b| == 0.5, 1.0, or 2.0.
373static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter) {
374 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
375 Value operandA = op.getOperand(0);
376 Value operandB = op.getOperand(1);
377 auto typeA = operandA.getType();
378 auto typeB = operandB.getType();
379
380 auto &sem =
381 cast<mlir::FloatType>(getElementTypeOrSelf(typeB)).getFloatSemantics();
382 APFloat valueB(sem);
383 auto mulf = [&](Value x, Value y) -> Value {
384 return arith::MulFOp::create(b, x, y);
385 };
386 if (matchPattern(operandB, m_ConstantFloat(&valueB))) {
387 if (valueB.isZero()) {
388 // a^0 -> 1
389 Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
390 rewriter.replaceOp(op, one);
391 return success();
392 }
393 if (valueB.isExactlyValue(1.0)) {
394 // a^1 -> a
395 rewriter.replaceOp(op, operandA);
396 return success();
397 }
398 if (valueB.isExactlyValue(-1.0)) {
399 // a^(-1) -> 1 / a
400 Value one = createFloatConst(op->getLoc(), typeA, 1.0, rewriter);
401 Value div = arith::DivFOp::create(b, one, operandA);
402 rewriter.replaceOp(op, div);
403 return success();
404 }
405 if (valueB.isExactlyValue(0.5)) {
406 // a^(1/2) -> sqrt(a)
407 Value sqrt = math::SqrtOp::create(b, operandA);
408 rewriter.replaceOp(op, sqrt);
409 return success();
410 }
411 if (valueB.isExactlyValue(-0.5)) {
412 // a^(-1/2) -> 1 / sqrt(a)
413 Value rsqrt = math::RsqrtOp::create(b, operandA);
414 rewriter.replaceOp(op, rsqrt);
415 return success();
416 }
417 if (valueB.isExactlyValue(2.0)) {
418 // a^2 -> a * a
419 rewriter.replaceOp(op, mulf(operandA, operandA));
420 return success();
421 }
422 if (valueB.isExactlyValue(-2.0)) {
423 // a^(-2) -> 1 / (a * a)
424 Value one =
425 createFloatConst(op->getLoc(), operandA.getType(), 1.0, rewriter);
426 Value div = arith::DivFOp::create(b, one, mulf(operandA, operandA));
427 rewriter.replaceOp(op, div);
428 return success();
429 }
430 if (valueB.isExactlyValue(3.0)) {
431 rewriter.replaceOp(op, mulf(mulf(operandA, operandA), operandA));
432 return success();
433 }
434 }
435
436 Value logA = math::LogOp::create(b, operandA);
437 Value mult = arith::MulFOp::create(b, operandB, logA);
438 Value expResult = math::ExpOp::create(b, mult);
439 rewriter.replaceOp(op, expResult);
440 return success();
441}
442
443// exp2f(float x) -> exp(x * ln(2))
444// Proof: Let's say 2^x = y
445// ln(2^x) = ln(y)
446// x * ln(2) = ln(y) => e ^(x*ln(2)) = y
447static LogicalResult convertExp2fOp(math::Exp2Op op,
448 PatternRewriter &rewriter) {
449 ImplicitLocOpBuilder b(op->getLoc(), rewriter);
450 Value operand = op.getOperand();
451 Type opType = operand.getType();
452 Value ln2 = createFloatConst(op->getLoc(), opType, llvm::numbers::ln2, b);
453 Value mult = arith::MulFOp::create(b, opType, operand, ln2);
454 Value exp = math::ExpOp::create(b, op->getLoc(), mult);
455 rewriter.replaceOp(op, exp);
456 return success();
457}
458
459static LogicalResult convertRoundOp(math::RoundOp op,
460 PatternRewriter &rewriter) {
461 Location loc = op.getLoc();
462 ImplicitLocOpBuilder b(loc, rewriter);
463 Value operand = op.getOperand();
464 Type opType = operand.getType();
465 Type opEType = getElementTypeOrSelf(opType);
466
467 if (!opEType.isF32()) {
468 return rewriter.notifyMatchFailure(op, "not a round of f32.");
469 }
470
471 Type i32Ty = b.getI32Type();
472 if (auto shapedTy = dyn_cast<ShapedType>(opType))
473 i32Ty = shapedTy.clone(i32Ty);
474
475 Value half = createFloatConst(loc, opType, 0.5, b);
476 Value c23 = createIntConst(loc, i32Ty, 23, b);
477 Value c127 = createIntConst(loc, i32Ty, 127, b);
478 Value expMask = createIntConst(loc, i32Ty, (1 << 8) - 1, b);
479
480 Value incrValue = math::CopySignOp::create(b, half, operand);
481 Value add = arith::AddFOp::create(b, opType, operand, incrValue);
482 Value fpFixedConvert = createTruncatedFPValue(add, b);
483
484 // There are three cases where adding 0.5 to the value and truncating by
485 // converting to an i64 does not result in the correct behavior:
486 //
487 // 1. Special values: +-inf and +-nan
488 // Casting these special values to i64 has undefined behavior. To identify
489 // these values, we use the fact that these values are the only float
490 // values with the maximum possible biased exponent.
491 //
492 // 2. Large values: 2^23 <= |x| <= INT_64_MAX
493 // Adding 0.5 to a float larger than or equal to 2^23 results in precision
494 // errors that sometimes round the value up and sometimes round the value
495 // down. For example:
496 // 8388608.0 + 0.5 = 8388608.0
497 // 8388609.0 + 0.5 = 8388610.0
498 //
499 // 3. Very large values: |x| > INT_64_MAX
500 // Casting to i64 a value greater than the max i64 value will overflow the
501 // i64 leading to wrong outputs.
502 //
503 // All three cases satisfy the property `biasedExp >= 23`.
504 Value operandBitcast = arith::BitcastOp::create(b, i32Ty, operand);
505 Value operandExp = arith::AndIOp::create(
506 b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
507 Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
508 Value isSpecialValOrLargeVal = arith::CmpIOp::create(
509 b, arith::CmpIPredicate::sge, operandBiasedExp, c23);
510
511 Value result = arith::SelectOp::create(b, isSpecialValOrLargeVal, operand,
512 fpFixedConvert);
513 rewriter.replaceOp(op, result);
514 return success();
515}
516
517// Converts math.ctlz to scf and arith operations. This is done
518// by performing a binary search on the bits.
519static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op,
520 PatternRewriter &rewriter) {
521 auto operand = op.getOperand();
522 auto operandTy = operand.getType();
523 auto eTy = getElementTypeOrSelf(operandTy);
524 Location loc = op.getLoc();
525
526 // Only expand for integer or float element types (index has no fixed bitwidth).
527 if (!eTy.isIntOrFloat()) {
528 return rewriter.notifyMatchFailure(op, "ctlz expansion only supports int or float types");
529 }
530
531 int32_t bitwidth = eTy.getIntOrFloatBitWidth();
532 if (bitwidth > 64)
533 return failure();
534
535 uint64_t allbits = -1;
536 if (bitwidth < 64) {
537 allbits = allbits >> (64 - bitwidth);
538 }
539
540 Value x = operand;
541 Value count = createIntConst(loc, operandTy, 0, rewriter);
542 for (int32_t bw = bitwidth; bw > 1; bw = bw / 2) {
543 auto half = bw / 2;
544 auto bits = createIntConst(loc, operandTy, half, rewriter);
545 auto mask = createIntConst(loc, operandTy, allbits >> half, rewriter);
546
547 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::ule,
548 x, mask);
549 Value add = arith::AddIOp::create(rewriter, loc, count, bits);
550 Value shift = arith::ShLIOp::create(rewriter, loc, x, bits);
551
552 x = arith::SelectOp::create(rewriter, loc, pred, shift, x);
553 count = arith::SelectOp::create(rewriter, loc, pred, add, count);
554 }
555
556 Value zero = createIntConst(loc, operandTy, 0, rewriter);
557 Value pred = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
558 operand, zero);
559
560 Value bwval = createIntConst(loc, operandTy, bitwidth, rewriter);
561 Value sel = arith::SelectOp::create(rewriter, loc, pred, bwval, count);
562 rewriter.replaceOp(op, sel);
563 return success();
564}
565
566// Convert `math.roundeven` into `math.round` + arith ops
567static LogicalResult convertRoundEvenOp(math::RoundEvenOp op,
568 PatternRewriter &rewriter) {
569 Location loc = op.getLoc();
570 ImplicitLocOpBuilder b(loc, rewriter);
571 auto operand = op.getOperand();
572 Type operandTy = operand.getType();
573 Type resultTy = op.getType();
574 Type operandETy = getElementTypeOrSelf(operandTy);
575 Type resultETy = getElementTypeOrSelf(resultTy);
576
577 if (!isa<FloatType>(operandETy) || !isa<FloatType>(resultETy)) {
578 return rewriter.notifyMatchFailure(op, "not a roundeven of f16 or f32.");
579 }
580
581 Type fTy = operandTy;
582 Type iTy = rewriter.getIntegerType(operandETy.getIntOrFloatBitWidth());
583 if (auto shapedTy = dyn_cast<ShapedType>(fTy)) {
584 iTy = shapedTy.clone(iTy);
585 }
586
587 unsigned bitWidth = operandETy.getIntOrFloatBitWidth();
588 // The width returned by getFPMantissaWidth includes the integer bit.
589 unsigned mantissaWidth =
590 llvm::cast<FloatType>(operandETy).getFPMantissaWidth() - 1;
591 unsigned exponentWidth = bitWidth - mantissaWidth - 1;
592
593 // The names of the variables correspond to f32.
594 // f64: 1 bit sign | 11 bits exponent | 52 bits mantissa.
595 // f32: 1 bit sign | 8 bits exponent | 23 bits mantissa.
596 // f16: 1 bit sign | 5 bits exponent | 10 bits mantissa.
597 Value c1Float = createFloatConst(loc, fTy, 1.0, b);
598 Value c0 = createIntConst(loc, iTy, 0, b);
599 Value c1 = createIntConst(loc, iTy, 1, b);
600 Value cNeg1 = createIntConst(loc, iTy, -1, b);
601 Value c23 = createIntConst(loc, iTy, mantissaWidth, b);
602 Value c31 = createIntConst(loc, iTy, bitWidth - 1, b);
603 Value c127 = createIntConst(loc, iTy, (1ull << (exponentWidth - 1)) - 1, b);
604 Value c2To22 = createIntConst(loc, iTy, 1ull << (mantissaWidth - 1), b);
605 Value c23Mask = createIntConst(loc, iTy, (1ull << mantissaWidth) - 1, b);
606 Value expMask = createIntConst(loc, iTy, (1ull << exponentWidth) - 1, b);
607
608 Value operandBitcast = arith::BitcastOp::create(b, iTy, operand);
609 Value round = math::RoundOp::create(b, operand);
610 Value roundBitcast = arith::BitcastOp::create(b, iTy, round);
611
612 // Get biased exponents for operand and round(operand)
613 Value operandExp = arith::AndIOp::create(
614 b, arith::ShRUIOp::create(b, operandBitcast, c23), expMask);
615 Value operandBiasedExp = arith::SubIOp::create(b, operandExp, c127);
616 Value roundExp = arith::AndIOp::create(
617 b, arith::ShRUIOp::create(b, roundBitcast, c23), expMask);
618 Value roundBiasedExp = arith::SubIOp::create(b, roundExp, c127);
619
620 auto safeShiftRight = [&](Value x, Value shift) -> Value {
621 // Clamp shift to valid range [0, bitwidth - 1] to avoid undefined behavior
622 Value clampedShift = arith::MaxSIOp::create(b, shift, c0);
623 clampedShift = arith::MinSIOp::create(b, clampedShift, c31);
624 return arith::ShRUIOp::create(b, x, clampedShift);
625 };
626
627 auto maskMantissa = [&](Value mantissa,
628 Value mantissaMaskRightShift) -> Value {
629 Value shiftedMantissaMask = safeShiftRight(c23Mask, mantissaMaskRightShift);
630 return arith::AndIOp::create(b, mantissa, shiftedMantissaMask);
631 };
632
633 // A whole number `x`, such that `|x| != 1`, is even if the mantissa, ignoring
634 // the leftmost `clamp(biasedExp - 1, 0, 23)` bits, is zero. Large numbers
635 // with `biasedExp > 23` (numbers where there is not enough precision to store
636 // decimals) are always even, and they satisfy the even condition trivially
637 // since the mantissa without all its bits is zero. The even condition
638 // is also true for +-0, since they have `biasedExp = -127` and the entire
639 // mantissa is zero. The case of +-1 has to be handled separately. Here
640 // we identify these values by noting that +-1 are the only whole numbers with
641 // `biasedExp == 0`.
642 //
643 // The special values +-inf and +-nan also satisfy the same property that
644 // whole non-unit even numbers satisfy. In particular, the special values have
645 // `biasedExp > 23`, so they get treated as large numbers with no room for
646 // decimals, which are always even.
647 Value roundBiasedExpEq0 =
648 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, roundBiasedExp, c0);
649 Value roundBiasedExpMinus1 = arith::SubIOp::create(b, roundBiasedExp, c1);
650 Value roundMaskedMantissa = maskMantissa(roundBitcast, roundBiasedExpMinus1);
651 Value roundIsNotEvenOrSpecialVal = arith::CmpIOp::create(
652 b, arith::CmpIPredicate::ne, roundMaskedMantissa, c0);
653 roundIsNotEvenOrSpecialVal =
654 arith::OrIOp::create(b, roundIsNotEvenOrSpecialVal, roundBiasedExpEq0);
655
656 // A value `x` with `0 <= biasedExp < 23`, is halfway between two consecutive
657 // integers if the bit at index `biasedExp` starting from the left in the
658 // mantissa is 1 and all the bits to the right are zero. Values with
659 // `biasedExp >= 23` don't have decimals, so they are never halfway. The
660 // values +-0.5 are the only halfway values that have `biasedExp == -1 < 0`,
661 // so these are handled separately. In particular, if `biasedExp == -1`, the
662 // value is halfway if the entire mantissa is zero.
663 Value operandBiasedExpEqNeg1 = arith::CmpIOp::create(
664 b, arith::CmpIPredicate::eq, operandBiasedExp, cNeg1);
665 Value expectedOperandMaskedMantissa = arith::SelectOp::create(
666 b, operandBiasedExpEqNeg1, c0, safeShiftRight(c2To22, operandBiasedExp));
667 Value operandMaskedMantissa = maskMantissa(operandBitcast, operandBiasedExp);
668 Value operandIsHalfway =
669 arith::CmpIOp::create(b, arith::CmpIPredicate::eq, operandMaskedMantissa,
670 expectedOperandMaskedMantissa);
671 // Ensure `biasedExp` is in the valid range for half values.
672 Value operandBiasedExpGeNeg1 = arith::CmpIOp::create(
673 b, arith::CmpIPredicate::sge, operandBiasedExp, cNeg1);
674 Value operandBiasedExpLt23 = arith::CmpIOp::create(
675 b, arith::CmpIPredicate::slt, operandBiasedExp, c23);
676 operandIsHalfway =
677 arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpLt23);
678 operandIsHalfway =
679 arith::AndIOp::create(b, operandIsHalfway, operandBiasedExpGeNeg1);
680
681 // Adjust rounded operand with `round(operand) - sign(operand)` to correct the
682 // case where `round` rounded in the opposite direction of `roundeven`.
683 Value sign = math::CopySignOp::create(b, c1Float, operand);
684 Value roundShifted = arith::SubFOp::create(b, round, sign);
685 // If the rounded value is even or a special value, we default to the behavior
686 // of `math.round`.
687 Value needsShift =
688 arith::AndIOp::create(b, roundIsNotEvenOrSpecialVal, operandIsHalfway);
689 Value result = arith::SelectOp::create(b, needsShift, roundShifted, round);
690 // The `x - sign` adjustment does not preserve the sign when we are adjusting
691 // the value -1 to -0. So here the sign is copied again to ensure that -0.5 is
692 // rounded to -0.0.
693 result = math::CopySignOp::create(b, result, operand);
694 rewriter.replaceOp(op, result);
695 return success();
696}
697
698// Convert `math.rsqrt` into `arith.divf` + `math.sqrt`
699static LogicalResult convertRsqrtOp(math::RsqrtOp op,
700 PatternRewriter &rewriter) {
701
702 auto operand = op.getOperand();
703 auto operandTy = operand.getType();
704 // Operand type must be shatic shaped type to create const float.
705 auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
706 if (shapedOperandType && !shapedOperandType.hasStaticShape())
707 return failure();
708
709 auto eTy = getElementTypeOrSelf(operandTy);
710 if (!isa<FloatType>(eTy))
711 return failure();
712
713 Location loc = op->getLoc();
714 auto constOneFloat = createFloatConst(loc, operandTy, 1.0, rewriter);
715 auto sqrtOp = math::SqrtOp::create(rewriter, loc, operand);
716 rewriter.replaceOpWithNewOp<arith::DivFOp>(op, constOneFloat, sqrtOp);
717 return success();
718}
719
720// Convert `math.clampf` into `arith.minimumf` + `arith.maximumf`
721static LogicalResult convertClampfOp(math::ClampFOp op,
722 PatternRewriter &rewriter) {
723 auto minOp = arith::MinimumFOp::create(rewriter, op.getLoc(), op.getValue(),
724 op.getMax(), op.getFastmath());
725 rewriter.replaceOpWithNewOp<arith::MaximumFOp>(op, minOp, op.getMin(),
726 op.getFastmath());
727 return success();
728}
729
731 ArrayRef<StringRef> opMnemonics) {
732 auto filter = [&](StringRef name) {
733 // This should be a static assert and `consume_front` take a twine, but none
734 // is currently possible. TODO: augment `StringRef::consume_front` and make
735 // `getDialectNamespace` use `std::string_view`.
736 assert("math" == MathDialect::getDialectNamespace());
737 name.consume_front("math.");
738 return opMnemonics.empty() || (llvm::count(opMnemonics, name) > 0);
739 };
740 if (filter(CountLeadingZerosOp::getOperationName()))
741 patterns.add(convertCtlzOp);
742 if (filter(SinhOp::getOperationName()))
743 patterns.add(convertSinhOp);
744 if (filter(CoshOp::getOperationName()))
745 patterns.add(convertCoshOp);
746 if (filter(TanOp::getOperationName()))
747 patterns.add(convertTanOp);
748 if (filter(TanhOp::getOperationName()))
749 patterns.add(convertTanhOp);
750 if (filter(AsinhOp::getOperationName()))
751 patterns.add(convertAsinhOp);
752 if (filter(AcoshOp::getOperationName()))
753 patterns.add(convertAcoshOp);
754 if (filter(AtanhOp::getOperationName()))
755 patterns.add(convertAtanhOp);
756 if (filter(FmaOp::getOperationName()))
757 patterns.add(convertFmaFOp);
758 if (filter(CeilOp::getOperationName()))
759 patterns.add(convertCeilOp);
760 if (filter(Exp2Op::getOperationName()))
761 patterns.add(convertExp2fOp);
762 if (filter(PowFOp::getOperationName()))
763 patterns.add(convertPowfOp);
764 if (filter(FPowIOp::getOperationName()))
765 patterns.add(convertFPowIOp);
766 if (filter(RoundOp::getOperationName()))
767 patterns.add(convertRoundOp);
768 if (filter(RoundEvenOp::getOperationName()))
769 patterns.add(convertRoundEvenOp);
770 if (filter(RsqrtOp::getOperationName()))
771 patterns.add(convertRsqrtOp);
772 if (filter(ClampFOp::getOperationName()))
773 patterns.add(convertClampfOp);
774}
775
776//===----------------------------------------------------------------------===//
777// MathExpandOpsPass pass
778//===----------------------------------------------------------------------===//
779namespace {
780struct MathExpandOpsPass final
781 : math::impl::MathExpandOpsPassBase<MathExpandOpsPass> {
782 using MathExpandOpsPassBase::MathExpandOpsPassBase;
783
784 void runOnOperation() override {
785 RewritePatternSet patterns(&getContext());
786 SmallVector<StringRef> mnemonics =
787 llvm::to_vector_of<StringRef>(opMnemonics);
788 math::populateExpansionPatterns(patterns, mnemonics);
789 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
790 return signalPassFailure();
791 }
792};
793} // namespace
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static LogicalResult convertRsqrtOp(math::RsqrtOp op, PatternRewriter &rewriter)
static Value createTruncatedFPValue(Value operand, ImplicitLocOpBuilder &b)
Definition ExpandOps.cpp:62
static LogicalResult convertFPowIOp(math::FPowIOp op, PatternRewriter &rewriter)
static LogicalResult convertPowfOp(math::PowFOp op, PatternRewriter &rewriter)
static LogicalResult convertClampfOp(math::ClampFOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundOp(math::RoundOp op, PatternRewriter &rewriter)
static LogicalResult convertTanOp(math::TanOp op, PatternRewriter &rewriter)
static LogicalResult convertCtlzOp(math::CountLeadingZerosOp op, PatternRewriter &rewriter)
static LogicalResult convertFmaFOp(math::FmaOp op, PatternRewriter &rewriter)
static LogicalResult convertAtanhOp(math::AtanhOp op, PatternRewriter &rewriter)
static LogicalResult convertCeilOp(math::CeilOp op, PatternRewriter &rewriter)
static LogicalResult convertRoundEvenOp(math::RoundEvenOp op, PatternRewriter &rewriter)
static LogicalResult convertCoshOp(math::CoshOp op, PatternRewriter &rewriter)
Definition ExpandOps.cpp:91
static LogicalResult convertSinhOp(math::SinhOp op, PatternRewriter &rewriter)
Definition ExpandOps.cpp:75
static Value createFloatConst(Location loc, Type type, APFloat value, OpBuilder &b)
Create a float constant.
Definition ExpandOps.cpp:29
static LogicalResult convertAsinhOp(math::AsinhOp op, PatternRewriter &rewriter)
static Value createIntConst(Location loc, Type type, int64_t value, OpBuilder &b)
Create an integer constant.
Definition ExpandOps.cpp:51
static LogicalResult convertTanhOp(math::TanhOp op, PatternRewriter &rewriter)
Expands tanh op into (1 - exp(-2x)) / (1 + exp(-2x)). To avoid overflow we exploit the reflection symmetry tanh(-x) = -tanh(x).
static LogicalResult convertAcoshOp(math::AcoshOp op, PatternRewriter &rewriter)
static LogicalResult convertExp2fOp(math::Exp2Op op, PatternRewriter &rewriter)
#define add(a, b)
#define div(a, b)
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ImplicitLocOpBuilder maintains a 'current location', allowing use of the create<> method without spec...
Definition Builders.h:632
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
This class helps build Operations.
Definition Builders.h:209
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isF32() const
Definition Types.cpp:40
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
void populateExpansionPatterns(RewritePatternSet &patterns, ArrayRef< StringRef > opMnemonics={})
Adds patterns to expand math operations into other more fundamental operations.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
Type getElementTypeOrSelf(Type type)
Return the element type or return the type itself.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
detail::constant_float_value_binder m_ConstantFloat(FloatAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor float (splat) and writes the float value to bind_va...
Definition Matchers.h:520