22#define GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS
23#include "mlir/Conversion/Passes.h.inc"
/// Look up (or declare) the runtime-library function `_mlir_apfloat_<name>`
/// used to lower a binary floating-point arithmetic op.
///
/// The declaration's parameter list is (i32, i64, i64). The binary-op pattern
/// in this file calls its helper with (APFloat semantics value, lhs bits,
/// rhs bits). Each float operand is bitcast to an integer and zero-extended
/// to 64 bits.
///
/// NOTE(review): several lines of this definition are missing from this view,
/// including its parameter list and the call that consumes `funcName`. Per the
/// cross-reference, the full signature is
///   lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable,
///                          StringRef name,
///                          SymbolTableCollection *symbolTables = nullptr).
/// The declared result type is not visible here. Callers truncate
/// `getResult(0)`, so it is presumably i64 — TODO confirm.
39static FailureOr<FuncOp>
42 auto i32Type = IntegerType::get(symTable->getContext(), 32);
43 auto i64Type = IntegerType::get(symTable->getContext(), 64);
// All binary helpers share the `_mlir_apfloat_` symbol prefix.
44 std::string funcName = (llvm::Twine(
"_mlir_apfloat_") + name).str();
46 {i32Type, i64Type, i64Type}, symbolTables);
/// Rewrite pattern that lowers a binary floating-point arith op (OpTy) into a
/// call to a runtime APFloat helper.
///
/// In this pass it is registered for AddFOp, SubFOp, MulFOp, DivFOp, RemFOp,
/// MinNumFOp, MaxNumFOp, MinimumFOp and MaximumFOp. The corresponding helper
/// names are "add", "subtract", "multiply", "divide", "remainder", "minnum",
/// "maxnum", "minimum" and "maximum".
///
/// Vectors are handled element-wise via forEachScalarValue. For each scalar
/// element the rewrite:
///   1. bitcasts each float operand to an integer of the same width and
///      zero-extends it to i64,
///   2. calls the helper with (semantics value, lhs bits, rhs bits),
///   3. truncates the call's first result back to the float width and
///      bitcasts it to the result float type.
/// Finally, it replaces the original op with the rebuilt value.
///
/// NOTE(review): lines of this definition are missing from this view,
/// including the constructor, the helper lookup, and the call's result types.
50template <
typename OpTy>
65 FailureOr<FuncOp> fn =
74 rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
77 auto floatTy = cast<FloatType>(resultType);
78 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
79 auto int64Type = rewriter.getI64Type();
// Pass each operand's raw bit pattern, widened to the helper's i64 parameters.
80 Value lhsBits = arith::ExtUIOp::create(
81 rewriter, loc, int64Type,
82 arith::BitcastOp::create(rewriter, loc, intWType, lhs));
83 Value rhsBits = arith::ExtUIOp::create(
84 rewriter, loc, int64Type,
85 arith::BitcastOp::create(rewriter, loc, intWType, rhs));
88 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
89 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
90 auto resultOp = func::CallOp::create(rewriter, loc,
92 SymbolRefAttr::get(*fn), params);
// Narrow the returned bits to the float width and reinterpret them as a float.
95 Value truncatedBits = arith::TruncIOp::create(rewriter, loc, intWType,
96 resultOp->getResult(0));
97 return arith::BitcastOp::create(rewriter, loc, floatTy,
100 rewriter.replaceOp(op, repl);
/// Rewrite pattern that lowers a float-to-float conversion (OpTy) into a
/// runtime APFloat call. It is registered for arith::ExtFOp and
/// arith::TruncFOp.
///
/// The helper is declared with parameters (i32, i32, i64). It is called with
/// (input semantics value, output semantics value, input bits). The input
/// bits are the operand bitcast to an integer of its own width and
/// zero-extended to i64. The call's first result is truncated to the output
/// float width and bitcast to the output float type.
///
/// NOTE(review): the helper's symbol name and the lookup call are on lines
/// missing from this view — TODO confirm against the full source.
108template <
typename OpTy>
120 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
121 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
122 FailureOr<FuncOp> fn =
124 {i32Type, i32Type, i64Type});
132 rewriter, loc, op.getOperand(),
Value(), op.getType(),
135 auto inFloatTy = cast<FloatType>(operand1.getType());
136 auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
137 Value operandBits = arith::ExtUIOp::create(
138 rewriter, loc, i64Type,
139 arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
// Source and target formats are both passed so the helper can convert.
142 Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
143 auto outFloatTy = cast<FloatType>(resultType);
145 getAPFloatSemanticsValue(rewriter, loc, outFloatTy);
146 std::array<Value, 3> params = {inSemValue, outSemValue, operandBits};
147 auto resultOp = func::CallOp::create(rewriter, loc,
149 SymbolRefAttr::get(*fn), params);
152 auto outIntWType = rewriter.
getIntegerType(outFloatTy.getWidth());
153 Value truncatedBits = arith::TruncIOp::create(
154 rewriter, loc, outIntWType, resultOp->getResult(0));
155 return arith::BitcastOp::create(rewriter, loc, outFloatTy,
158 rewriter.replaceOp(op, repl);
/// Rewrite pattern that lowers a float-to-integer conversion (OpTy) into a
/// runtime APFloat call. It is registered for arith::FPToSIOp and
/// arith::FPToUIOp.
///
/// The helper is declared with parameters (i32, i32, i1, i64). It is called
/// with four arguments:
///   - the input semantics value,
///   - the result integer width as an i32 constant,
///   - `isUnsigned` as an i1 constant,
///   - the input bits.
/// The input float is bitcast to an integer of its own width and
/// zero-extended to i64. The call's first result is truncated to the result
/// integer type.
///
/// NOTE(review): `isUnsigned` is a constructor argument (see the
/// FpToIntConversion constructor signature). The values passed at
/// registration are on lines missing from this view. They are presumably
/// false for FPToSIOp and true for FPToUIOp.
165template <
typename OpTy>
178 auto i1Type = IntegerType::get(
symTable->getContext(), 1);
179 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
180 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
181 FailureOr<FuncOp> fn =
183 {i32Type, i32Type, i1Type, i64Type});
191 rewriter, loc, op.getOperand(),
Value(), op.getType(),
194 auto inFloatTy = cast<FloatType>(operand1.getType());
195 auto inIntWType = rewriter.getIntegerType(inFloatTy.getWidth());
196 Value operandBits = arith::ExtUIOp::create(
197 rewriter, loc, i64Type,
198 arith::BitcastOp::create(rewriter, loc, inIntWType, operand1));
201 Value inSemValue = getAPFloatSemanticsValue(rewriter, loc, inFloatTy);
202 auto outIntTy = cast<IntegerType>(resultType);
// Target width and signedness are passed to the helper as constants.
203 Value outWidthValue = arith::ConstantOp::create(
204 rewriter, loc, i32Type,
205 rewriter.getIntegerAttr(i32Type, outIntTy.getWidth()));
206 Value isUnsignedValue = arith::ConstantOp::create(
207 rewriter, loc, i1Type,
208 rewriter.getIntegerAttr(i1Type, isUnsigned));
209 SmallVector<Value> params = {inSemValue, outWidthValue,
210 isUnsignedValue, operandBits};
211 auto resultOp = func::CallOp::create(rewriter, loc,
213 SymbolRefAttr::get(*fn), params);
216 return arith::TruncIOp::create(rewriter, loc, outIntTy,
217 resultOp->getResult(0));
219 rewriter.replaceOp(op, repl);
/// Rewrite pattern that lowers an integer-to-float conversion (OpTy) into a
/// call to the runtime helper `_mlir_apfloat_convert_from_int`. It is
/// registered for arith::SIToFPOp and arith::UIToFPOp. The helper is declared
/// with parameters (i32, i32, i1, i64).
///
/// Integer operands narrower than 64 bits are widened to i64. One branch
/// zero-extends (ExtUIOp) and the other sign-extends (ExtSIOp). The branch
/// condition is on a missing line and presumably tests `isUnsigned`.
///
/// The call's trailing arguments are:
///   - an i32 constant, presumably the input integer width,
///   - `isUnsigned` as an i1 constant,
///   - the operand bits.
/// The call's first result is truncated to the output float width and bitcast
/// to the output float type.
///
/// NOTE(review): operands that are already 64 bits or wider are passed through
/// unchanged here. A value wider than 64 bits would not match the helper's
/// i64 parameter, so confirm such types are rejected earlier (e.g. by
/// checkPreconditions). That check is not visible in this view.
227template <
typename OpTy>
240 auto i1Type = IntegerType::get(
symTable->getContext(), 1);
241 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
242 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
244 rewriter,
symTable,
"_mlir_apfloat_convert_from_int",
245 {i32Type, i32Type, i1Type, i64Type});
253 rewriter, loc, op.getOperand(),
Value(), op.getType(),
256 auto inIntTy = cast<IntegerType>(operand1.getType());
257 Value operandBits = operand1;
// Widen sub-64-bit integers to the helper's i64 operand.
258 if (operandBits.getType().getIntOrFloatBitWidth() < 64) {
261 arith::ExtUIOp::create(rewriter, loc, i64Type, operandBits);
264 arith::ExtSIOp::create(rewriter, loc, i64Type, operandBits);
269 auto outFloatTy = cast<FloatType>(resultType);
272 Value inWidthValue = arith::ConstantOp::create(
273 rewriter, loc, i32Type,
275 Value isUnsignedValue = arith::ConstantOp::create(
276 rewriter, loc, i1Type,
279 isUnsignedValue, operandBits};
280 auto resultOp = func::CallOp::create(rewriter, loc,
282 SymbolRefAttr::get(*fn), params);
285 auto outIntWType = rewriter.
getIntegerType(outFloatTy.getWidth());
286 Value truncatedBits = arith::TruncIOp::create(
287 rewriter, loc, outIntWType, resultOp->getResult(0));
288 return arith::BitcastOp::create(rewriter, loc, outFloatTy,
291 rewriter.replaceOp(op, repl);
// Lower arith.cmpf (CmpFOpToAPFloatConversion) to a runtime APFloat compare.
//
// The helper is declared with parameters (i32, i64, i64) and an i8 result.
// The trailing `nullptr, i8Type` presumably fill the symbolTables and
// resultType parameters of lookupOrCreateFnDecl. It is called with
// (semantics value, lhs bits, rhs bits).
//
// The i8 result is compared against llvm::APFloat::cmpResult enumerators to
// build the i1 predicate result:
//   - O* predicates accept only ordered outcomes.
//   - U* predicates additionally accept cmpUnordered.
//   - ORD accepts lt|gt|eq; UNO accepts only unordered.
//   - AlwaysFalse/AlwaysTrue are materialized as i1 constants.
//
// NOTE(review): the pattern header, helper symbol name and several statements
// are on lines missing from this view.
310 auto i1Type = IntegerType::get(
symTable->getContext(), 1);
311 auto i8Type = IntegerType::get(
symTable->getContext(), 8);
312 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
313 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
314 FailureOr<FuncOp> fn =
316 {i32Type, i64Type, i64Type},
nullptr, i8Type);
324 rewriter, loc, op.getLhs(), op.getRhs(), op.getType(),
327 auto floatTy = cast<FloatType>(lhs.getType());
328 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
329 Value lhsBits = arith::ExtUIOp::create(
330 rewriter, loc, i64Type,
331 arith::BitcastOp::create(rewriter, loc, intWType, lhs));
332 Value rhsBits = arith::ExtUIOp::create(
333 rewriter, loc, i64Type,
334 arith::BitcastOp::create(rewriter, loc, intWType, rhs));
337 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
338 SmallVector<Value> params = {semValue, lhsBits, rhsBits};
339 Value comparisonResult =
340 func::CallOp::create(rewriter, loc,
TypeRange(i8Type),
341 SymbolRefAttr::get(*fn), params)
// checkResult: i1 that is true when the helper's i8 result equals `val`.
// checkResults (below) ORs checkResult over several outcomes, recursively.
346 auto checkResult = [&](llvm::APFloat::cmpResult val) {
347 return arith::CmpIOp::create(
348 rewriter, loc, arith::CmpIPredicate::eq, comparisonResult,
349 arith::ConstantOp::create(
350 rewriter, loc, i8Type,
358 Value first = checkResult(vals.front());
359 if (vals.size() == 1)
361 Value rest = checkResults(vals.drop_front());
362 return arith::OrIOp::create(rewriter, loc, first, rest)
// Map each predicate to the set of APFloat comparison outcomes it accepts.
368 switch (op.getPredicate()) {
369 case arith::CmpFPredicate::AlwaysFalse:
371 arith::ConstantOp::create(rewriter, loc, i1Type,
375 case arith::CmpFPredicate::OEQ:
376 result = checkResult(llvm::APFloat::cmpEqual);
378 case arith::CmpFPredicate::OGT:
379 result = checkResult(llvm::APFloat::cmpGreaterThan);
381 case arith::CmpFPredicate::OGE:
383 {llvm::APFloat::cmpGreaterThan, llvm::APFloat::cmpEqual});
385 case arith::CmpFPredicate::OLT:
386 result = checkResult(llvm::APFloat::cmpLessThan);
388 case arith::CmpFPredicate::OLE:
390 {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpEqual});
392 case arith::CmpFPredicate::ONE:
395 {llvm::APFloat::cmpLessThan, llvm::APFloat::cmpGreaterThan});
397 case arith::CmpFPredicate::ORD:
399 result = checkResults({llvm::APFloat::cmpLessThan,
400 llvm::APFloat::cmpGreaterThan,
401 llvm::APFloat::cmpEqual});
403 case arith::CmpFPredicate::UEQ:
405 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpEqual});
407 case arith::CmpFPredicate::UGT:
409 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpGreaterThan});
411 case arith::CmpFPredicate::UGE:
412 result = checkResults({llvm::APFloat::cmpUnordered,
413 llvm::APFloat::cmpGreaterThan,
414 llvm::APFloat::cmpEqual});
416 case arith::CmpFPredicate::ULT:
418 {llvm::APFloat::cmpUnordered, llvm::APFloat::cmpLessThan});
420 case arith::CmpFPredicate::ULE:
421 result = checkResults({llvm::APFloat::cmpUnordered,
422 llvm::APFloat::cmpLessThan,
423 llvm::APFloat::cmpEqual});
425 case arith::CmpFPredicate::UNE:
427 result = checkResults({llvm::APFloat::cmpLessThan,
428 llvm::APFloat::cmpGreaterThan,
429 llvm::APFloat::cmpUnordered});
431 case arith::CmpFPredicate::UNO:
432 result = checkResult(llvm::APFloat::cmpUnordered);
434 case arith::CmpFPredicate::AlwaysTrue:
436 arith::ConstantOp::create(rewriter, loc, i1Type,
443 rewriter.replaceOp(op, repl);
// Lower arith.negf (NegFOpToAPFloatConversion) to a call to the runtime
// helper `_mlir_apfloat_neg`, declared with parameters (i32, i64).
//
// The operand is bitcast to an integer of its own width, zero-extended to i64
// and passed with its semantics value. The helper's i64 result is truncated to
// the float width and bitcast back to the float type.
//
// NOTE(review): the assignment of `negatedBits` from the call result is on a
// line missing from this view.
461 auto i32Type = IntegerType::get(
symTable->getContext(), 32);
462 auto i64Type = IntegerType::get(
symTable->getContext(), 64);
464 rewriter,
symTable,
"_mlir_apfloat_neg", {i32Type, i64Type});
475 auto floatTy = cast<FloatType>(operand1.getType());
476 auto intWType = rewriter.getIntegerType(floatTy.getWidth());
477 Value operandBits = arith::ExtUIOp::create(
478 rewriter, loc, i64Type,
479 arith::BitcastOp::create(rewriter, loc, intWType, operand1));
482 Value semValue = getAPFloatSemanticsValue(rewriter, loc, floatTy);
483 SmallVector<Value> params = {semValue, operandBits};
485 func::CallOp::create(rewriter, loc,
TypeRange(i64Type),
486 SymbolRefAttr::get(*fn), params)
490 Value truncatedBits =
491 arith::TruncIOp::create(rewriter, loc, intWType, negatedBits);
492 return arith::BitcastOp::create(rewriter, loc, floatTy,
495 rewriter.replaceOp(op, repl);
/// Pass that rewrites floating-point arith ops into calls to runtime APFloat
/// helper functions. The pass base class is generated via
/// GEN_PASS_DEF_ARITHTOAPFLOATCONVERSIONPASS. Pattern registration happens in
/// runOnOperation().
503struct ArithToAPFloatConversionPass final
504 : impl::ArithToAPFloatConversionPassBase<ArithToAPFloatConversionPass> {
507 void runOnOperation()
override;
/// Register the APFloat lowering patterns and run them on the current
/// operation.
///
/// getOperation() is passed to every pattern as its `symTable` argument. Helper
/// function declarations are therefore looked up or created in that
/// operation.
///
/// A ScopedDiagnosticHandler captures `result` by reference and inspects
/// Error-severity diagnostics. The pass ends via signalPassFailure(). The
/// guarding condition is on a missing line and presumably checks for a
/// failed `result`.
///
/// NOTE(review): the pattern-application call and the handler body are on
/// lines missing from this view. The cross-reference lists
/// walkAndApplyPatterns — TODO confirm that it is the driver used.
510void ArithToAPFloatConversionPass::runOnOperation() {
512 RewritePatternSet
patterns(context);
// Binary ops: the string names the runtime helper ("_mlir_apfloat_<name>").
513 patterns.add<BinaryArithOpToAPFloatConversion<arith::AddFOp>>(context,
"add",
515 patterns.add<BinaryArithOpToAPFloatConversion<arith::SubFOp>>(
516 context,
"subtract", getOperation());
517 patterns.add<BinaryArithOpToAPFloatConversion<arith::MulFOp>>(
518 context,
"multiply", getOperation());
519 patterns.add<BinaryArithOpToAPFloatConversion<arith::DivFOp>>(
520 context,
"divide", getOperation());
521 patterns.add<BinaryArithOpToAPFloatConversion<arith::RemFOp>>(
522 context,
"remainder", getOperation());
523 patterns.add<BinaryArithOpToAPFloatConversion<arith::MinNumFOp>>(
524 context,
"minnum", getOperation());
525 patterns.add<BinaryArithOpToAPFloatConversion<arith::MaxNumFOp>>(
526 context,
"maxnum", getOperation());
527 patterns.add<BinaryArithOpToAPFloatConversion<arith::MinimumFOp>>(
528 context,
"minimum", getOperation());
529 patterns.add<BinaryArithOpToAPFloatConversion<arith::MaximumFOp>>(
530 context,
"maximum", getOperation());
532 .add<FpToFpConversion<arith::ExtFOp>, FpToFpConversion<arith::TruncFOp>,
533 CmpFOpToAPFloatConversion, NegFOpToAPFloatConversion>(
534 context, getOperation());
535 patterns.add<FpToIntConversion<arith::FPToSIOp>>(context, getOperation(),
537 patterns.add<FpToIntConversion<arith::FPToUIOp>>(context, getOperation(),
539 patterns.add<IntToFpConversion<arith::SIToFPOp>>(context, getOperation(),
541 patterns.add<IntToFpConversion<arith::UIToFPOp>>(context, getOperation(),
// Intercept error diagnostics, presumably those emitted while applying the
// patterns.
544 ScopedDiagnosticHandler scopedHandler(context, [&
result](Diagnostic &
diag) {
545 if (
diag.getSeverity() == DiagnosticSeverity::Error) {
554 return signalPassFailure();
static FailureOr< FuncOp > lookupOrCreateBinaryFn(OpBuilder &b, SymbolOpInterface symTable, StringRef name, SymbolTableCollection *symbolTables=nullptr)
Helper function to look up or create the symbol for a runtime library function for a binary arithmeti...
static std::string diag(const llvm::Value &value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class helps build Operations.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a collection of SymbolTables.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
FailureOr< FuncOp > lookupOrCreateFnDecl(OpBuilder &b, SymbolOpInterface symTable, StringRef name, TypeRange paramTypes, SymbolTableCollection *symbolTables=nullptr, Type resultType={})
Helper function to look up or create the symbol for a runtime library function with the given paramet...
Include the generated interface declarations.
LogicalResult checkPreconditions(RewriterBase &rewriter, Operation *op)
Check preconditions for the conversion:
Value getAPFloatSemanticsValue(OpBuilder &b, Location loc, FloatType floatTy)
Value forEachScalarValue(mlir::RewriterBase &rewriter, Location loc, Value operand1, Value operand2, Type resultType, llvm::function_ref< Value(Value, Value, Type)> fn)
Given two operands of vector type and vector result type (with the same shape), call the given functi...
const FrozenRewritePatternSet & patterns
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
BinaryArithOpToAPFloatConversion(MLIRContext *context, const char *APFloatName, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(arith::CmpFOp op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
CmpFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
FpToFpConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
FpToIntConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit=1)
SymbolOpInterface symTable
LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override
IntToFpConversion(MLIRContext *context, SymbolOpInterface symTable, bool isUnsigned, PatternBenefit benefit=1)
LogicalResult matchAndRewrite(arith::NegFOp op, PatternRewriter &rewriter) const override
NegFOpToAPFloatConversion(MLIRContext *context, SymbolOpInterface symTable, PatternBenefit benefit=1)
SymbolOpInterface symTable
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})