MLIR 22.0.0git
ArithToAPFloat.cpp
Go to the documentation of this file.
1//===- ArithToAPFloat.cpp - Arithmetic to APFloat Conversion --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "Utils.h"
10
18#include "mlir/IR/Verifier.h"
20
21namespace mlir {
22#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
23#include "mlir/Conversion/Passes.h.inc"
24} // namespace mlir
25
26using namespace mlir;
27using namespace mlir::func;
28
29/// Helper function to look up or create the symbol for a runtime library
30/// function for a binary arithmetic operation.
31///
32/// Parameter 1: APFloat semantics
33/// Parameter 2: Left-hand side operand
34/// Parameter 3: Right-hand side operand
35///
36/// This function will return a failure if the function is found but has an
37/// unexpected signature.
38///
39static FailureOr<FuncOp>
40lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name,
41 SymbolTableCollection *symbolTables = nullptr) {
42 auto i32Type = IntegerType::get(symTable->getContext(), 32);
43 auto i64Type = IntegerType::get(symTable->getContext(), 64);
44 std::string funcName = (llvm::Twine("_mlir_apfloat_") + name).str();
45 return lookupOrCreateFnDecl(b, symTable, funcName,
46 {i32Type, i64Type, i64Type}, symbolTables);
47}
48
49/// Given two operands of vector type and vector result type (with the same
50/// shape), call the given function for each pair of scalar operands and
51/// package the result into a vector. If the given operands and result type are
52/// not vectors, call the function directly. The second operand is optional.
53template <typename Fn, typename... Values>
55 Value operand1, Value operand2, Type resultType,
56 Fn fn) {
57 auto vecTy1 = dyn_cast<VectorType>(operand1.getType());
58 if (operand2) {
59 // Sanity check: Operand types must match.
60 assert(vecTy1 == dyn_cast<VectorType>(operand2.getType()) &&
61 "expected same vector types");
62 }
63 if (!vecTy1) {
64 // Not a vector. Call the function directly.
65 return fn(operand1, operand2, resultType);
66 }
67
68 // Prepare scalar operands.
69 ResultRange sclars1 =
70 vector::ToElementsOp::create(rewriter, loc, operand1)->getResults();
71 SmallVector<Value> scalars2;
72 if (!operand2) {
73 // No second operand. Create a vector of empty values.
74 scalars2.assign(vecTy1.getNumElements(), Value());
75 } else {
76 llvm::append_range(
77 scalars2,
78 vector::ToElementsOp::create(rewriter, loc, operand2)->getResults());
79 }
80
81 // Call the function for each pair of scalar operands.
82 auto resultVecType = cast<VectorType>(resultType);
83 SmallVector<Value> results;
84 for (auto [scalar1, scalar2] : llvm::zip_equal(sclars1, scalars2)) {
85 Value result = fn(scalar1, scalar2, resultVecType.getElementType());
86 results.push_back(result);
87 }
88
89 // Package the results into a vector.
90 return vector::FromElementsOp::create(
91 rewriter, loc,
92 vecTy1.cloneWith(/*shape=*/std::nullopt, results.front().getType()),
93 results);
94}
95
96/// Check preconditions for the conversion:
97/// 1. All operands / results must be integers or floats (or vectors thereof).
98/// 2. The bitwidth of the operands / results must be <= 64.
99static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op) {
100 for (Value value : llvm::concat<Value>(op->getOperands(), op->getResults())) {
101 Type type = value.getType();
102 if (auto vecTy = dyn_cast<VectorType>(type)) {
103 type = vecTy.getElementType();
104 }
105 if (!type.isIntOrFloat()) {
106 return rewriter.notifyMatchFailure(
107 op, "only integers and floats (or vectors thereof) are supported");
108 }
109 if (type.getIntOrFloatBitWidth() > 64)
110 return rewriter.notifyMatchFailure(op,
111 "bitwidth > 64 bits is not supported");
112 }
113 return success();
114}
115
116/// Rewrite a binary arithmetic operation to an APFloat function call.
117template <typename OpTy>
120 const char *APFloatName,
121 SymbolOpInterface symTable,
122 PatternBenefit benefit = 1)
123 : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
125
126 LogicalResult matchAndRewrite(OpTy op,
127 PatternRewriter &rewriter) const override {
128 if (failed(checkPreconditions(rewriter, op)))
129 return failure();
130
131 // Get APFloat function from runtime library.
132 FailureOr<FuncOp> fn =
134 if (failed(fn))
135 return fn;
136
137 // Scalarize and convert to APFloat runtime calls.
138 Location loc = op.getLoc();
139 rewriter.setInsertionPoint(op);
141 rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
142 [&](Value lhs, Value rhs, Type resultType) {
143 // Cast operands to 64-bit integers.
144 auto floatTy = cast<FloatType>(resultType);
145 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
146 auto int64Type = rewriter.getI64Type();
147 Value lhsBits = arith::ExtUIOp::create(
148 rewriter, loc, int64Type,
149 arith::BitcastOp::create(rewriter, loc, intWType, lhs));
150 Value rhsBits = arith::ExtUIOp::create(
151 rewriter, loc, int64Type,
152 arith::BitcastOp::create(rewriter, loc, intWType, rhs));
153
154 // Call APFloat function.
155 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
156 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
157 auto resultOp = func::CallOp::create(rewriter, loc,
158 TypeRange(rewriter.getI64Type()),
159 SymbolRefAttr::get(*fn), params);
160
161 // Truncate result to the original width.
162 Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
163 resultOp->getResult(0));
164 return arith::BitcastOp::create(rewriter, loc, floatTy,
165 truncatedBits);
166 });
167 rewriter.replaceOp(op, repl);
168 return success();
169 }
170
171 SymbolOpInterface symTable;
172 const char *APFloatName;
173};
174
175template <typename OpTy>
176struct FpToFpConversion final : OpRewritePattern<OpTy> {
177 FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
178 PatternBenefit benefit = 1)
179 : OpRewritePattern<OpTy>(context, benefit), symTable(symTable) {}
180
181 LogicalResult matchAndRewrite(OpTy op,
182 PatternRewriter &rewriter) const override {
183 if (failed(checkPreconditions(rewriter, op)))
184 return failure();
185
186 // Get APFloat function from runtime library.
187 auto i32Type = IntegerType::get(symTable->getContext(), 32);
188 auto i64Type = IntegerType::get(symTable->getContext(), 64);
189 FailureOr<FuncOp> fn =
190 lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert",
191 {i32Type, i32Type, i64Type});
192 if (failed(fn))
193 return fn;
194
195 // Scalarize and convert to APFloat runtime calls.
196 Location loc = op.getLoc();
197 rewriter.setInsertionPoint(op);
199 rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
200 [&](Value operand1, Value operand2, Type resultType) {
201 // Cast operands to 64-bit integers.
202 auto inFloatTy = cast<FloatType>(operand1.getType());
203 auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
204 Value operandBits = arith::ExtUIOp::create(
205 rewriter, loc, i64Type,
206 arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
207
208 // Call APFloat function.
209 Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
210 auto outFloatTy = cast<FloatType>(resultType);
211 Value outSemValue =
212 getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
213 std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
214 auto resultOp = func::CallOp::create(rewriter, loc,
215 TypeRange(rewriter.getI64Type()),
216 SymbolRefAttr::get(*fn), params);
217
218 // Truncate result to the original width.
219 auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
220 Value truncatedBits = arith::TruncIOp::create(
221 rewriter, loc, outIntWType, resultOp->getResult(0));
222 return arith::BitcastOp::create(rewriter, loc, outFloatTy,
223 truncatedBits);
224 });
225 rewriter.replaceOp(op, repl);
226 return success();
227 }
228
229 SymbolOpInterface symTable;
230};
231
232template <typename OpTy>
234 FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable,
235 bool isUnsigned, PatternBenefit benefit = 1)
236 : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
238
239 LogicalResult matchAndRewrite(OpTy op,
240 PatternRewriter &rewriter) const override {
241 if (failed(checkPreconditions(rewriter, op)))
242 return failure();
243
244 // Get APFloat function from runtime library.
245 auto i1Type = IntegerType::get(symTable->getContext(), 1);
246 auto i32Type = IntegerType::get(symTable->getContext(), 32);
247 auto i64Type = IntegerType::get(symTable->getContext(), 64);
248 FailureOr<FuncOp> fn =
249 lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_convert_to_int",
250 {i32Type, i32Type, i1Type, i64Type});
251 if (failed(fn))
252 return fn;
253
254 // Scalarize and convert to APFloat runtime calls.
255 Location loc = op.getLoc();
256 rewriter.setInsertionPoint(op);
258 rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
259 [&](Value operand1, Value operand2, Type resultType) {
260 // Cast operands to 64-bit integers.
261 auto inFloatTy = cast<FloatType>(operand1.getType());
262 auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
263 Value operandBits = arith::ExtUIOp::create(
264 rewriter, loc, i64Type,
265 arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
266
267 // Call APFloat function.
268 Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
269 auto outIntTy = cast<IntegerType>(resultType);
270 Value outWidthValue = arith::ConstantOp::create(
271 rewriter, loc, i32Type,
272 rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
273 Value isUnsignedValue = arith::ConstantOp::create(
274 rewriter, loc, i1Type,
275 rewriter.getIntegerAttr(i1Type, isUnsigned));
276 SmallVector<Value> params = {inSemValue, outWidthValue,
277 isUnsignedValue, operandBits};
278 auto resultOp = func::CallOp::create(rewriter, loc,
279 TypeRange(rewriter.getI64Type()),
280 SymbolRefAttr::get(*fn), params);
281
282 // Truncate result to the original width.
283 return arith::TruncIOp::create(rewriter, loc, outIntTy,
284 resultOp->getResult(0));
285 });
286 rewriter.replaceOp(op, repl);
287 return success();
288 }
289
290 SymbolOpInterface symTable;
292};
293
294template <typename OpTy>
296 IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable,
297 bool isUnsigned, PatternBenefit benefit = 1)
298 : OpRewritePattern<OpTy>(context, benefit), symTable(symTable),
300
301 LogicalResult matchAndRewrite(OpTy op,
302 PatternRewriter &rewriter) const override {
303 if (failed(checkPreconditions(rewriter, op)))
304 return failure();
305
306 // Get APFloat function from runtime library.
307 auto i1Type = IntegerType::get(symTable->getContext(), 1);
308 auto i32Type = IntegerType::get(symTable->getContext(), 32);
309 auto i64Type = IntegerType::get(symTable->getContext(), 64);
310 FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
311 rewriter, symTable, "_mlir_apfloat_convert_from_int",
312 {i32Type, i32Type, i1Type, i64Type});
313 if (failed(fn))
314 return fn;
315
316 // Scalarize and convert to APFloat runtime calls.
317 Location loc = op.getLoc();
318 rewriter.setInsertionPoint(op);
320 rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
321 [&](Value operand1, Value operand2, Type resultType) {
322 // Cast operands to 64-bit integers.
323 auto inIntTy = cast<IntegerType>(operand1.getType());
324 Value operandBits = operand1;
325 if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
326 if (isUnsigned) {
327 operandBits =
328 arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
329 } else {
330 operandBits =
331 arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
332 }
333 }
334
335 // Call APFloat function.
336 auto outFloatTy = cast<FloatType>(resultType);
337 Value outSemValue =
338 getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
339 Value inWidthValue = arith::ConstantOp::create(
340 rewriter, loc, i32Type,
341 rewriter.getIntegerAttr(i32Type, inIntTy.getWidth()));
342 Value isUnsignedValue = arith::ConstantOp::create(
343 rewriter, loc, i1Type,
344 rewriter.getIntegerAttr(i1Type, isUnsigned));
345 SmallVector<Value> params = {outSemValue, inWidthValue,
346 isUnsignedValue, operandBits};
347 auto resultOp = func::CallOp::create(rewriter, loc,
348 TypeRange(rewriter.getI64Type()),
349 SymbolRefAttr::get(*fn), params);
350
351 // Truncate result to the original width.
352 auto outIntWType = rewriter.getIntegerType(outFloatTy.getWidth());
353 Value truncatedBits = arith::TruncIOp::create(
354 rewriter, loc, outIntWType, resultOp->getResult(0));
355 return arith::BitcastOp::create(rewriter, loc, outFloatTy,
356 truncatedBits);
357 });
358 rewriter.replaceOp(op, repl);
359 return success();
360 }
361
362 SymbolOpInterface symTable;
364};
365
366struct CmpFOpToAPFloatConversion final : OpRewritePattern<arith::CmpFOp> {
367 CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
368 PatternBenefit benefit = 1)
369 : OpRewritePattern<arith::CmpFOp>(context, benefit), symTable(symTable) {}
370
371 LogicalResult matchAndRewrite(arith::CmpFOp op,
372 PatternRewriter &rewriter) const override {
373 if (failed(checkPreconditions(rewriter, op)))
374 return failure();
375
376 // Get APFloat function from runtime library.
377 auto i1Type = IntegerType::get(symTable->getContext(), 1);
378 auto i8Type = IntegerType::get(symTable->getContext(), 8);
379 auto i32Type = IntegerType::get(symTable->getContext(), 32);
380 auto i64Type = IntegerType::get(symTable->getContext(), 64);
381 FailureOr<FuncOp> fn =
382 lookupOrCreateFnDecl(rewriter, symTable, "_mlir_apfloat_compare",
383 {i32Type, i64Type, i64Type}, nullptr, i8Type);
384 if (failed(fn))
385 return fn;
386
387 // Scalarize and convert to APFloat runtime calls.
388 Location loc = op.getLoc();
389 rewriter.setInsertionPoint(op);
391 rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
392 [&](Value lhs, Value rhs, Type resultType) {
393 // Cast operands to 64-bit integers.
394 auto floatTy = cast<FloatType>(lhs.getType());
395 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
396 Value lhsBits = arith::ExtUIOp::create(
397 rewriter, loc, i64Type,
398 arith::BitcastOp::create(rewriter, loc, intWType, lhs));
399 Value rhsBits = arith::ExtUIOp::create(
400 rewriter, loc, i64Type,
401 arith::BitcastOp::create(rewriter, loc, intWType, rhs));
402
403 // Call APFloat function.
404 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
405 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
406 Value comparisonResult =
407 func::CallOp::create(rewriter, loc, TypeRange(i8Type),
408 SymbolRefAttr::get(*fn), params)
409 ->getResult(0);
410
411 // Generate an i1 SSA value that is "true" if the comparison result
412 // matches the given `val`.
413 auto checkResult = [&](llvm::APFloat::cmpResult val) {
414 return arith::CmpIOp::create(
415 rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
416 arith::ConstantOp::create(
417 rewriter, loc, i8Type,
418 rewriter.getIntegerAttr(i8Type, static_cast<int8_t>(val)))
419 .getResult());
420 };
421 // Generate an i1 SSA value that is "true" if the comparison result
422 // matches any of the given `vals`.
424 checkResults = [&](ArrayRef<llvm::APFloat::cmpResult> vals) {
425 Value first = checkResult(vals.front());
426 if (vals.size() == 1)
427 return first;
428 Value rest = checkResults(vals.drop_front());
429 return arith::OrIOp::create(rewriter, loc, first, rest)
430 .getResult();
431 };
432
433 // This switch-case statement was taken from arith::applyCmpPredicate.
435 switch (op.getPredicate()) {
436 case arith::CmpFPredicate::AlwaysFalse:
437 result =
438 arith::ConstantOp::create(rewriter, loc, i1Type,
439 rewriter.getIntegerAttr(i1Type, 0))
440 .getResult();
441 break;
442 case arith::CmpFPredicate::OEQ:
443 result = checkResult(llvm::APFloat::cmpEqual);
444 break;
445 case arith::CmpFPredicate::OGT:
446 result = checkResult(llvm::APFloat::cmpGreaterThan);
447 break;
448 case arith::CmpFPredicate::OGE:
449 result = checkResults(
450 {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
451 break;
452 case arith::CmpFPredicate::OLT:
453 result = checkResult(llvm::APFloat::cmpLessThan);
454 break;
455 case arith::CmpFPredicate::OLE:
456 result = checkResults(
457 {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
458 break;
459 case arith::CmpFPredicate::ONE:
460 // Not cmpUnordered and not cmpUnordered.
461 result = checkResults(
462 {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
463 break;
464 case arith::CmpFPredicate::ORD:
465 // Not cmpUnordered.
466 result = checkResults({llvm::APFloat::cmpLessThan,
467 llvm::APFloat::cmpGreaterThan,
468 llvm::APFloat::cmpEqual});
469 break;
470 case arith::CmpFPredicate::UEQ:
471 result = checkResults(
472 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
473 break;
474 case arith::CmpFPredicate::UGT:
475 result = checkResults(
476 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
477 break;
478 case arith::CmpFPredicate::UGE:
479 result = checkResults({llvm::APFloat::cmpUnordered,
480 llvm::APFloat::cmpGreaterThan,
481 llvm::APFloat::cmpEqual});
482 break;
483 case arith::CmpFPredicate::ULT:
484 result = checkResults(
485 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
486 break;
487 case arith::CmpFPredicate::ULE:
488 result = checkResults({llvm::APFloat::cmpUnordered,
489 llvm::APFloat::cmpLessThan,
490 llvm::APFloat::cmpEqual});
491 break;
492 case arith::CmpFPredicate::UNE:
493 // Not cmpEqual.
494 result = checkResults({llvm::APFloat::cmpLessThan,
495 llvm::APFloat::cmpGreaterThan,
496 llvm::APFloat::cmpUnordered});
497 break;
498 case arith::CmpFPredicate::UNO:
499 result = checkResult(llvm::APFloat::cmpUnordered);
500 break;
501 case arith::CmpFPredicate::AlwaysTrue:
502 result =
503 arith::ConstantOp::create(rewriter, loc, i1Type,
504 rewriter.getIntegerAttr(i1Type, 1))
505 .getResult();
506 break;
507 }
508 return result;
509 });
510 rewriter.replaceOp(op, repl);
511 return success();
512 }
513
514 SymbolOpInterface symTable;
515};
516
517struct NegFOpToAPFloatConversion final : OpRewritePattern<arith::NegFOp> {
518 NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable,
519 PatternBenefit benefit = 1)
520 : OpRewritePattern<arith::NegFOp>(context, benefit), symTable(symTable) {}
521
522 LogicalResult matchAndRewrite(arith::NegFOp op,
523 PatternRewriter &rewriter) const override {
524 if (failed(checkPreconditions(rewriter, op)))
525 return failure();
526
527 // Get APFloat function from runtime library.
528 auto i32Type = IntegerType::get(symTable->getContext(), 32);
529 auto i64Type = IntegerType::get(symTable->getContext(), 64);
530 FailureOr<FuncOp> fn = lookupOrCreateFnDecl(
531 rewriter, symTable, "_mlir_apfloat_neg", {i32Type, i64Type});
532 if (failed(fn))
533 return fn;
534
535 // Scalarize and convert to APFloat runtime calls.
536 Location loc = op.getLoc();
537 rewriter.setInsertionPoint(op);
539 rewriter, loc, op.getOperand(), /*operand2=*/Value(), op.getType(),
540 [&](Value operand1, Value operand2, Type resultType) {
541 // Cast operands to 64-bit integers.
542 auto floatTy = cast<FloatType>(operand1.getType());
543 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
544 Value operandBits = arith::ExtUIOp::create(
545 rewriter, loc, i64Type,
546 arith::BitcastOp::create(rewriter, loc, intWType, operand1));
547
548 // Call APFloat function.
549 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
550 SmallVector<Value> params = {semValue, operandBits};
551 Value negatedBits =
552 func::CallOp::create(rewriter, loc, TypeRange(i64Type),
553 SymbolRefAttr::get(*fn), params)
554 ->getResult(0);
555
556 // Truncate result to the original width.
557 Value truncatedBits =
558 arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
559 return arith::BitcastOp::create(rewriter, loc, floatTy,
560 truncatedBits);
561 });
562 rewriter.replaceOp(op, repl);
563 return success();
564 }
565
566 SymbolOpInterface symTable;
567};
568
569namespace {
570struct ArithToAPFloatConversionPass final
571 : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
572 using Base::Base;
573
574 void runOnOperation() override;
575};
576
577void ArithToAPFloatConversionPass::runOnOperation() {
578 MLIRContext *context = &getContext();
579 RewritePatternSet patterns(context);
580 patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context, "add",
581 getOperation());
582 patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>(
583 context, "subtract", getOperation());
584 patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>(
585 context, "multiply", getOperation());
586 patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>(
587 context, "divide", getOperation());
588 patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
589 context, "remainder", getOperation());
590 patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>(
591 context, "minnum", getOperation());
592 patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>(
593 context, "maxnum", getOperation());
594 patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>(
595 context, "minimum", getOperation());
596 patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
597 context, "maximum", getOperation());
599 .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
600 CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
601 context, getOperation());
602 patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
603 /*isUnsigned=*/false);
604 patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
605 /*isUnsigned=*/true);
606 patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
607 /*isUnsigned=*/false);
608 patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
609 /*isUnsigned=*/true);
610 LogicalResult result = success();
611 ScopedDiagnosticHandler scopedHandler(context, [&result](Diagnostic &diag) {
612 if (diag.getSeverity() == DiagnosticSeverity::Error) {
613 result = failure();
614 }
615 // NB: if you don't return failure, no other diag handlers will fire (see
616 // mlir/lib/IR/Diagnostics.cpp:DiagnosticEngineImpl::emit).
617 return failure();
618 });
619 walkAndApplyPatterns(getOperation(), std::move(patterns));
620 if (failed(result))
621 return signalPassFailure();
622}
623} // namespace
return success()
static Value forEachScalarValue(RewriterBase &rewriter, Location loc, Value operand1, Value operand2, Type resultType, Fn fn)
Given two operands of vector type and vector result type (with the same shape), call the given functi...
static LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op)
Check preconditions for the conversion:
static FailureOr< FuncOp > lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, SymbolTableCollection *symbolTables=nullptr)
Helper function to look up or create the symbol for a runtime library function for a binary arithmeti...
lhs
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
static std::string diag(const llvm::Value &value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition Builders.cpp:228
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class helps build Operations.
Definition Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class implements the result iterators for the Operation class.
Definition ValueRange.h:247
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
This class represents a collection of SymbolTables.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:116
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
FailureOr< FuncOp > lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, SymbolTableCollection *symbolTables=nullptr, Type resultType={})
Helper function to look up or create the symbol for a runtime library function with the given paramet...
Definition Utils.cpp:302
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy)
Definition Utils.cpp:17
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
BinaryArithOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(arith::CmpFOp op, PatternRewriter &rewriter) const override
CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(arith::NegFOp op, PatternRewriter &rewriter) const override
NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})