MLIR  19.0.0git
ArithToEmitC.cpp
Go to the documentation of this file.
1 //===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert the Arith dialect to the EmitC
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
19 
20 using namespace mlir;
21 
22 //===----------------------------------------------------------------------===//
23 // Conversion Patterns
24 //===----------------------------------------------------------------------===//
25 
26 namespace {
27 class ArithConstantOpConversionPattern
28  : public OpConversionPattern<arith::ConstantOp> {
29 public:
31 
33  matchAndRewrite(arith::ConstantOp arithConst,
34  arith::ConstantOp::Adaptor adaptor,
35  ConversionPatternRewriter &rewriter) const override {
36  rewriter.replaceOpWithNewOp<emitc::ConstantOp>(
37  arithConst, arithConst.getType(), adaptor.getValue());
38  return success();
39  }
40 };
41 
42 class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
43 public:
45 
46  bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
47  switch (pred) {
48  case arith::CmpIPredicate::eq:
49  case arith::CmpIPredicate::ne:
50  case arith::CmpIPredicate::slt:
51  case arith::CmpIPredicate::sle:
52  case arith::CmpIPredicate::sgt:
53  case arith::CmpIPredicate::sge:
54  return false;
55  case arith::CmpIPredicate::ult:
56  case arith::CmpIPredicate::ule:
57  case arith::CmpIPredicate::ugt:
58  case arith::CmpIPredicate::uge:
59  return true;
60  }
61  llvm_unreachable("unknown cmpi predicate kind");
62  }
63 
64  emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
65  switch (pred) {
66  case arith::CmpIPredicate::eq:
67  return emitc::CmpPredicate::eq;
68  case arith::CmpIPredicate::ne:
69  return emitc::CmpPredicate::ne;
70  case arith::CmpIPredicate::slt:
71  case arith::CmpIPredicate::ult:
72  return emitc::CmpPredicate::lt;
73  case arith::CmpIPredicate::sle:
74  case arith::CmpIPredicate::ule:
75  return emitc::CmpPredicate::le;
76  case arith::CmpIPredicate::sgt:
77  case arith::CmpIPredicate::ugt:
78  return emitc::CmpPredicate::gt;
79  case arith::CmpIPredicate::sge:
80  case arith::CmpIPredicate::uge:
81  return emitc::CmpPredicate::ge;
82  }
83  llvm_unreachable("unknown cmpi predicate kind");
84  }
85 
87  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
88  ConversionPatternRewriter &rewriter) const override {
89 
90  Type type = adaptor.getLhs().getType();
91  if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
92  return rewriter.notifyMatchFailure(op, "expected integer or index type");
93  }
94 
95  bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
96  emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
97  Type arithmeticType = type;
98  if (type.isUnsignedInteger() != needsUnsigned) {
99  arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
100  /*isSigned=*/!needsUnsigned);
101  }
102  Value lhs = adaptor.getLhs();
103  Value rhs = adaptor.getRhs();
104  if (arithmeticType != type) {
105  lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
106  lhs);
107  rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
108  rhs);
109  }
110  rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
111  return success();
112  }
113 };
114 
115 template <typename ArithOp, typename EmitCOp>
116 class ArithOpConversion final : public OpConversionPattern<ArithOp> {
117 public:
119 
121  matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
122  ConversionPatternRewriter &rewriter) const override {
123 
124  rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, arithOp.getType(),
125  adaptor.getOperands());
126 
127  return success();
128  }
129 };
130 
131 template <typename ArithOp, typename EmitCOp>
132 class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
133 public:
135 
137  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
138  ConversionPatternRewriter &rewriter) const override {
139 
140  Type type = this->getTypeConverter()->convertType(op.getType());
141  if (!isa_and_nonnull<IntegerType, IndexType>(type)) {
142  return rewriter.notifyMatchFailure(op, "expected integer type");
143  }
144 
145  if (type.isInteger(1)) {
146  // arith expects wrap-around arithmethic, which doesn't happen on `bool`.
147  return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
148  }
149 
150  Value lhs = adaptor.getLhs();
151  Value rhs = adaptor.getRhs();
152  Type arithmeticType = type;
153  if ((type.isSignlessInteger() || type.isSignedInteger()) &&
154  !bitEnumContainsAll(op.getOverflowFlags(),
155  arith::IntegerOverflowFlags::nsw)) {
156  // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
157  // we compute in unsigned integers to avoid UB.
158  arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
159  /*isSigned=*/false);
160  }
161  if (arithmeticType != type) {
162  lhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
163  lhs);
164  rhs = rewriter.template create<emitc::CastOp>(op.getLoc(), arithmeticType,
165  rhs);
166  }
167 
168  Value result = rewriter.template create<EmitCOp>(op.getLoc(),
169  arithmeticType, lhs, rhs);
170 
171  if (arithmeticType != type) {
172  result =
173  rewriter.template create<emitc::CastOp>(op.getLoc(), type, result);
174  }
175  rewriter.replaceOp(op, result);
176  return success();
177  }
178 };
179 
180 class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
181 public:
183 
185  matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
186  ConversionPatternRewriter &rewriter) const override {
187 
188  Type dstType = getTypeConverter()->convertType(selectOp.getType());
189  if (!dstType)
190  return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
191 
192  if (!adaptor.getCondition().getType().isInteger(1))
193  return rewriter.notifyMatchFailure(
194  selectOp,
195  "can only be converted if condition is a scalar of type i1");
196 
197  rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
198  adaptor.getOperands());
199 
200  return success();
201  }
202 };
203 
204 // Floating-point to integer conversions.
205 template <typename CastOp>
206 class FtoICastOpConversion : public OpConversionPattern<CastOp> {
207 public:
208  FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
209  : OpConversionPattern<CastOp>(typeConverter, context) {}
210 
212  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
213  ConversionPatternRewriter &rewriter) const override {
214 
215  Type operandType = adaptor.getIn().getType();
216  if (!emitc::isSupportedFloatType(operandType))
217  return rewriter.notifyMatchFailure(castOp,
218  "unsupported cast source type");
219 
220  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
221  if (!dstType)
222  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
223 
224  // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
225  // truncated to 0, whereas a boolean conversion would return true.
226  if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
227  return rewriter.notifyMatchFailure(castOp,
228  "unsupported cast destination type");
229 
230  // Convert to unsigned if it's the "ui" variant
231  // Signless is interpreted as signed, so no need to cast for "si"
232  Type actualResultType = dstType;
233  if (isa<arith::FPToUIOp>(castOp)) {
234  actualResultType =
235  rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
236  /*isSigned=*/false);
237  }
238 
239  Value result = rewriter.create<emitc::CastOp>(
240  castOp.getLoc(), actualResultType, adaptor.getOperands());
241 
242  if (isa<arith::FPToUIOp>(castOp)) {
243  result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
244  }
245  rewriter.replaceOp(castOp, result);
246 
247  return success();
248  }
249 };
250 
251 // Integer to floating-point conversions.
252 template <typename CastOp>
253 class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
254 public:
255  ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
256  : OpConversionPattern<CastOp>(typeConverter, context) {}
257 
259  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
260  ConversionPatternRewriter &rewriter) const override {
261  // Vectors in particular are not supported
262  Type operandType = adaptor.getIn().getType();
263  if (!emitc::isSupportedIntegerType(operandType))
264  return rewriter.notifyMatchFailure(castOp,
265  "unsupported cast source type");
266 
267  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
268  if (!dstType)
269  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
270 
271  if (!emitc::isSupportedFloatType(dstType))
272  return rewriter.notifyMatchFailure(castOp,
273  "unsupported cast destination type");
274 
275  // Convert to unsigned if it's the "ui" variant
276  // Signless is interpreted as signed, so no need to cast for "si"
277  Type actualOperandType = operandType;
278  if (isa<arith::UIToFPOp>(castOp)) {
279  actualOperandType =
280  rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
281  /*isSigned=*/false);
282  }
283  Value fpCastOperand = adaptor.getIn();
284  if (actualOperandType != operandType) {
285  fpCastOperand = rewriter.template create<emitc::CastOp>(
286  castOp.getLoc(), actualOperandType, fpCastOperand);
287  }
288  rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
289 
290  return success();
291  }
292 };
293 
294 } // namespace
295 
296 //===----------------------------------------------------------------------===//
297 // Pattern population
298 //===----------------------------------------------------------------------===//
299 
301  RewritePatternSet &patterns) {
302  MLIRContext *ctx = patterns.getContext();
303 
304  // clang-format off
305  patterns.add<
306  ArithConstantOpConversionPattern,
307  ArithOpConversion<arith::AddFOp, emitc::AddOp>,
308  ArithOpConversion<arith::DivFOp, emitc::DivOp>,
309  ArithOpConversion<arith::MulFOp, emitc::MulOp>,
310  ArithOpConversion<arith::SubFOp, emitc::SubOp>,
311  IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
312  IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
313  IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
314  CmpIOpConversion,
315  SelectOpConversion,
316  ItoFCastOpConversion<arith::SIToFPOp>,
317  ItoFCastOpConversion<arith::UIToFPOp>,
318  FtoICastOpConversion<arith::FPToSIOp>,
319  FtoICastOpConversion<arith::FPToUIOp>
320  >(typeConverter, ctx);
321  // clang-format on
322 }
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:87
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
PatternRewriter hook for replacing an operation.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:718
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:79
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:67
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:91
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:125
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition: EmitC.cpp:116
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition: EmitC.cpp:95
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
void populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26