MLIR  22.0.0git
ArithToEmitC.cpp
Go to the documentation of this file.
1 //===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert the Arith dialect to the EmitC
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
21 #include "mlir/IR/BuiltinTypes.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 /// Implements the ConvertToEmitCPatternInterface for the Arith dialect by
/// forwarding to populateArithToEmitCPatterns.
28 struct ArithToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
30 
31  /// Hook for derived dialect interface to provide conversion patterns
32  /// and mark dialect legal for the conversion target.
  // NOTE(review): this override only adds the Arith patterns (sharing the
  // caller's type converter); it does not modify `target`.
33  void populateConvertToEmitCConversionPatterns(
34  ConversionTarget &target, TypeConverter &typeConverter,
35  RewritePatternSet &patterns) const final {
36  populateArithToEmitCPatterns(typeConverter, patterns);
37  }
38 };
39 } // namespace
40 
  // Register a registry extension whose callback receives the Arith dialect
  // instance and attaches the ConvertToEmitC interface to it (presumably run
  // when the dialect is loaded into a context using this registry).
42  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
43  dialect->addInterfaces<ArithToEmitCDialectInterface>();
44  });
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Conversion Patterns
49 //===----------------------------------------------------------------------===//
50 
51 namespace {
52 class ArithConstantOpConversionPattern
53  : public OpConversionPattern<arith::ConstantOp> {
54 public:
55  using Base::Base;
56 
57  LogicalResult
58  matchAndRewrite(arith::ConstantOp arithConst,
59  arith::ConstantOp::Adaptor adaptor,
60  ConversionPatternRewriter &rewriter) const override {
61  Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
62  if (!newTy)
63  return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
64  rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
65  adaptor.getValue());
66  return success();
67  }
68 };
69 
70 /// Get the signed or unsigned type corresponding to \p ty.
///
/// \p needsUnsigned selects the requested signedness. Only builtin integer
/// types and EmitC pointer-wide types are adapted; any other type is returned
/// unchanged. For integers the test is `isUnsignedInteger()`, so a signless
/// integer counts as "not unsigned": it is returned as-is when a non-unsigned
/// type is requested, and rebuilt with unsigned semantics otherwise.
71 Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
72  if (isa<IntegerType>(ty)) {
  // Only rebuild the type when its unsignedness differs from the request.
73  if (ty.isUnsignedInteger() != needsUnsigned) {
74  auto signedness = needsUnsigned
75  ? IntegerType::SignednessSemantics::Unsigned
78  signedness);
79  }
80  } else if (emitc::isPointerWideType(ty)) {
  // emitc.size_t is the unsigned pointer-wide type. An unsigned request maps
  // any other pointer-wide type to size_t; a signed request for size_t yields
  // its signed counterpart.
81  if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
82  if (needsUnsigned)
83  return emitc::SizeTType::get(ty.getContext());
85  }
86  }
87  return ty;
88 }
89 
90 /// Insert a cast operation to type \p ty if \p val does not have this type.
91 Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
92  return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
93 }
94 
95 class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
96 public:
97  using Base::Base;
98 
99  LogicalResult
100  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
101  ConversionPatternRewriter &rewriter) const override {
102 
103  if (!isa<FloatType>(adaptor.getRhs().getType())) {
104  return rewriter.notifyMatchFailure(op.getLoc(),
105  "cmpf currently only supported on "
106  "floats, not tensors/vectors thereof");
107  }
108 
109  bool unordered = false;
110  emitc::CmpPredicate predicate;
111  switch (op.getPredicate()) {
112  case arith::CmpFPredicate::AlwaysFalse: {
113  auto constant =
114  emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
115  rewriter.getBoolAttr(/*value=*/false));
116  rewriter.replaceOp(op, constant);
117  return success();
118  }
119  case arith::CmpFPredicate::OEQ:
120  unordered = false;
121  predicate = emitc::CmpPredicate::eq;
122  break;
123  case arith::CmpFPredicate::OGT:
124  unordered = false;
125  predicate = emitc::CmpPredicate::gt;
126  break;
127  case arith::CmpFPredicate::OGE:
128  unordered = false;
129  predicate = emitc::CmpPredicate::ge;
130  break;
131  case arith::CmpFPredicate::OLT:
132  unordered = false;
133  predicate = emitc::CmpPredicate::lt;
134  break;
135  case arith::CmpFPredicate::OLE:
136  unordered = false;
137  predicate = emitc::CmpPredicate::le;
138  break;
139  case arith::CmpFPredicate::ONE:
140  unordered = false;
141  predicate = emitc::CmpPredicate::ne;
142  break;
143  case arith::CmpFPredicate::ORD: {
144  // ordered, i.e. none of the operands is NaN
145  auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
146  adaptor.getRhs());
147  rewriter.replaceOp(op, cmp);
148  return success();
149  }
150  case arith::CmpFPredicate::UEQ:
151  unordered = true;
152  predicate = emitc::CmpPredicate::eq;
153  break;
154  case arith::CmpFPredicate::UGT:
155  unordered = true;
156  predicate = emitc::CmpPredicate::gt;
157  break;
158  case arith::CmpFPredicate::UGE:
159  unordered = true;
160  predicate = emitc::CmpPredicate::ge;
161  break;
162  case arith::CmpFPredicate::ULT:
163  unordered = true;
164  predicate = emitc::CmpPredicate::lt;
165  break;
166  case arith::CmpFPredicate::ULE:
167  unordered = true;
168  predicate = emitc::CmpPredicate::le;
169  break;
170  case arith::CmpFPredicate::UNE:
171  unordered = true;
172  predicate = emitc::CmpPredicate::ne;
173  break;
174  case arith::CmpFPredicate::UNO: {
175  // unordered, i.e. either operand is nan
176  auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
177  adaptor.getRhs());
178  rewriter.replaceOp(op, cmp);
179  return success();
180  }
181  case arith::CmpFPredicate::AlwaysTrue: {
182  auto constant =
183  emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
184  rewriter.getBoolAttr(/*value=*/true));
185  rewriter.replaceOp(op, constant);
186  return success();
187  }
188  }
189 
190  // Compare the values naively
191  auto cmpResult =
192  emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate,
193  adaptor.getLhs(), adaptor.getRhs());
194 
195  // Adjust the results for unordered/ordered semantics
196  if (unordered) {
197  auto isUnordered = createCheckIsUnordered(
198  rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
199  rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
200  isUnordered, cmpResult);
201  return success();
202  }
203 
204  auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
205  adaptor.getLhs(), adaptor.getRhs());
206  rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
207  isOrdered, cmpResult);
208  return success();
209  }
210 
211 private:
212  /// Return a value that is true if \p operand is NaN.
213  Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
214  Value operand) const {
215  // A value is NaN exactly when it compares unequal to itself.
216  return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
217  emitc::CmpPredicate::ne, operand, operand);
218  }
219 
220  /// Return a value that is true if \p operand is not NaN.
221  Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
222  Value operand) const {
223  // A value is not NaN exactly when it compares equal to itself.
224  return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
225  emitc::CmpPredicate::eq, operand, operand);
226  }
227 
228  /// Return a value that is true if the operands \p first and \p second are
229  /// unordered (i.e., at least one of them is NaN).
230  Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
231  Location loc, Value first, Value second) const {
232  auto firstIsNaN = isNaN(rewriter, loc, first);
233  auto secondIsNaN = isNaN(rewriter, loc, second);
234  return emitc::LogicalOrOp::create(rewriter, loc, rewriter.getI1Type(),
235  firstIsNaN, secondIsNaN);
236  }
237 
238  /// Return a value that is true if the operands \p first and \p second are
239  /// both ordered (i.e., none one of them is NaN).
240  Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
241  Value first, Value second) const {
242  auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
243  auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
244  return emitc::LogicalAndOp::create(rewriter, loc, rewriter.getI1Type(),
245  firstIsNotNaN, secondIsNotNaN);
246  }
247 };
248 
249 class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
250 public:
251  using Base::Base;
252 
253  bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
254  switch (pred) {
255  case arith::CmpIPredicate::eq:
256  case arith::CmpIPredicate::ne:
257  case arith::CmpIPredicate::slt:
258  case arith::CmpIPredicate::sle:
259  case arith::CmpIPredicate::sgt:
260  case arith::CmpIPredicate::sge:
261  return false;
262  case arith::CmpIPredicate::ult:
263  case arith::CmpIPredicate::ule:
264  case arith::CmpIPredicate::ugt:
265  case arith::CmpIPredicate::uge:
266  return true;
267  }
268  llvm_unreachable("unknown cmpi predicate kind");
269  }
270 
271  emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
272  switch (pred) {
273  case arith::CmpIPredicate::eq:
274  return emitc::CmpPredicate::eq;
275  case arith::CmpIPredicate::ne:
276  return emitc::CmpPredicate::ne;
277  case arith::CmpIPredicate::slt:
278  case arith::CmpIPredicate::ult:
279  return emitc::CmpPredicate::lt;
280  case arith::CmpIPredicate::sle:
281  case arith::CmpIPredicate::ule:
282  return emitc::CmpPredicate::le;
283  case arith::CmpIPredicate::sgt:
284  case arith::CmpIPredicate::ugt:
285  return emitc::CmpPredicate::gt;
286  case arith::CmpIPredicate::sge:
287  case arith::CmpIPredicate::uge:
288  return emitc::CmpPredicate::ge;
289  }
290  llvm_unreachable("unknown cmpi predicate kind");
291  }
292 
293  LogicalResult
294  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
295  ConversionPatternRewriter &rewriter) const override {
296 
297  Type type = adaptor.getLhs().getType();
298  if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
299  return rewriter.notifyMatchFailure(
300  op, "expected integer or size_t/ssize_t/ptrdiff_t type");
301  }
302 
303  bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
304  emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
305 
306  Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
307  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
308  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
309 
310  rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
311  return success();
312  }
313 };
314 
315 class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
316 public:
317  using Base::Base;
318 
319  LogicalResult
320  matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
321  ConversionPatternRewriter &rewriter) const override {
322 
323  auto adaptedOp = adaptor.getOperand();
324  auto adaptedOpType = adaptedOp.getType();
325 
326  if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
327  return rewriter.notifyMatchFailure(
328  op.getLoc(),
329  "negf currently only supports scalar types, not vectors or tensors");
330  }
331 
332  if (!emitc::isSupportedFloatType(adaptedOpType)) {
333  return rewriter.notifyMatchFailure(
334  op.getLoc(), "floating-point type is not supported by EmitC");
335  }
336 
337  rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType,
338  adaptedOp);
339  return success();
340  }
341 };
342 
/// Lowers a unary integer cast between builtin integer types and EmitC
/// pointer-wide types.
///
/// \tparam castToUnsigned Chooses the signedness of the width-changing cast.
///   - true: the cast is always done on unsigned types, so widening preserves
///     the unsigned value (zero extension).
///   - false: the cast is done on non-unsigned (signed/signless) types, so
///     widening preserves the signed value (sign extension). Truncations are
///     still done on unsigned types.
///
/// Lowering steps: cast the operand to the chosen signedness, cast to the
/// result width with that same signedness, then cast to the converted result
/// type. Casts to i1 are special-cased as `(bool)(v & 1)`.
343 template <typename ArithOp, bool castToUnsigned>
344 class CastConversion : public OpConversionPattern<ArithOp> {
345 public:
347 
348  LogicalResult
349  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
350  ConversionPatternRewriter &rewriter) const override {
351 
352  Type opReturnType = this->getTypeConverter()->convertType(op.getType());
353  if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
354  emitc::isPointerWideType(opReturnType)))
355  return rewriter.notifyMatchFailure(
356  op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
357 
  // Guard: this pattern reads a single input via adaptor.getIn().
358  if (adaptor.getOperands().size() != 1) {
359  return rewriter.notifyMatchFailure(
360  op, "CastConversion only supports unary ops");
361  }
362 
363  Type operandType = adaptor.getIn().getType();
364  if (!operandType || !(isa<IntegerType>(operandType) ||
365  emitc::isPointerWideType(operandType)))
366  return rewriter.notifyMatchFailure(
367  op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
368 
369  // Signed (sign-extending) casts from i1 are not supported.
370  if (operandType.isInteger(1) && !castToUnsigned)
371  return rewriter.notifyMatchFailure(op,
372  "operation not supported on i1 type");
373 
374  // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
375  // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
376  // truncation.
377  if (opReturnType.isInteger(1)) {
  // The constant 1 is typed as the operand type; for pointer-wide operands the
  // attribute itself is built with the index type.
378  Type attrType = (emitc::isPointerWideType(operandType))
379  ? rewriter.getIndexType()
380  : operandType;
381  auto constOne = emitc::ConstantOp::create(
382  rewriter, op.getLoc(), operandType, rewriter.getOneAttr(attrType));
383  auto oneAndOperand = emitc::BitwiseAndOp::create(
384  rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne);
385  rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
386  oneAndOperand);
387  return success();
388  }
389 
  // Narrowing between builtin integers is always done on unsigned types, where
  // conversion in C is modular (well-defined truncation).
390  bool isTruncation =
391  (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
392  operandType.getIntOrFloatBitWidth() >
393  opReturnType.getIntOrFloatBitWidth());
394  bool doUnsigned = castToUnsigned || isTruncation;
395 
396  // Adapt the signedness of the result (bitwidth-preserving cast)
397  // This is needed e.g., if the return type is signless.
398  Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
399 
400  // Adapt the signedness of the operand (bitwidth-preserving cast)
401  Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
402  Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
403 
404  // Actual cast (may change bitwidth)
405  auto cast =
406  emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp);
407 
408  // Cast to the expected output type
409  auto result = adaptValueType(cast, rewriter, opReturnType);
410 
411  rewriter.replaceOp(op, result);
412  return success();
413  }
414 };
415 
416 template <typename ArithOp>
417 class UnsignedCastConversion : public CastConversion<ArithOp, true> {
418  using CastConversion<ArithOp, true>::CastConversion;
419 };
420 
421 template <typename ArithOp>
422 class SignedCastConversion : public CastConversion<ArithOp, false> {
423  using CastConversion<ArithOp, false>::CastConversion;
424 };
425 
/// Generic one-to-one lowering of an arith op to the EmitC op \p EmitCOp.
///
/// The result type goes through the type converter. The adaptor's operands are
/// forwarded unchanged: no signedness adaptation or overflow handling is done
/// here.
426 template <typename ArithOp, typename EmitCOp>
427 class ArithOpConversion final : public OpConversionPattern<ArithOp> {
428 public:
430 
431  LogicalResult
432  matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
433  ConversionPatternRewriter &rewriter) const override {
434 
435  Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
436  if (!newTy)
437  return rewriter.notifyMatchFailure(arithOp,
438  "converting result type failed");
  // `template` keyword needed: rewriter call with an explicit template
  // argument inside a dependent context.
439  rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
440  adaptor.getOperands());
441 
442  return success();
443  }
444 };
445 
/// Lowers an unsigned binary integer op (e.g. divui/remui) to \p EmitCOp.
///
/// The operation is computed on unsigned types:
/// 1. Cast both operands to the unsigned form of the converted result type.
/// 2. Apply \p EmitCOp on those unsigned values.
/// 3. Cast the result back to the converted result type.
///
/// Only builtin integer result types are accepted (pointer-wide types are
/// rejected).
446 template <class ArithOp, class EmitCOp>
447 class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
448 public:
450 
451  LogicalResult
452  matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor,
453  ConversionPatternRewriter &rewriter) const override {
454  Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
455  if (!newRetTy)
456  return rewriter.notifyMatchFailure(uiBinOp,
457  "converting result type failed");
458  if (!isa<IntegerType>(newRetTy)) {
459  return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
460  }
461  Type unsignedType =
462  adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true);
463  if (!unsignedType)
464  return rewriter.notifyMatchFailure(uiBinOp,
465  "converting result type failed");
  // NOTE(review): operands are read from the original op (uiBinOp.getLhs()/
  // getRhs()) rather than from `adaptor`, unlike the other patterns in this
  // file. If an operand's type is converted, the cast would see the
  // unconverted value. Consider adaptor.getLhs()/getRhs() — confirm.
466  Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
467  Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
468 
  // Named `newDivOp`, but this is whichever EmitCOp the pattern was
  // instantiated with (div or rem).
469  auto newDivOp = EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType,
470  ArrayRef<Value>{lhsAdapted, rhsAdapted});
471  Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
472  rewriter.replaceOp(uiBinOp, resultAdapted);
473  return success();
474  }
475 };
476 
/// Lowers wrapping integer arithmetic (add/sub/mul) to \p EmitCOp.
///
/// Without the `nsw` flag, arith requires wrap-around, but signed overflow is
/// UB in C. So for signless/signed builtin integers the operation is computed
/// in the unsigned integer type of the same width, and the result is cast back.
477 template <typename ArithOp, typename EmitCOp>
478 class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
479 public:
481 
482  LogicalResult
483  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
484  ConversionPatternRewriter &rewriter) const override {
485 
486  Type type = this->getTypeConverter()->convertType(op.getType());
487  if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
488  return rewriter.notifyMatchFailure(
489  op, "expected integer or size_t/ssize_t/ptrdiff_t type");
490  }
491 
492  if (type.isInteger(1)) {
493  // arith expects wrap-around arithmetic, which doesn't happen on `bool`.
494  return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
495  }
496 
  // NOTE(review): isSignlessInteger()/isSignedInteger() are false for EmitC
  // pointer-wide types, so signed ones (ssize_t/ptrdiff_t) keep their signed
  // type here even without `nsw` — confirm whether that is intended.
  // NOTE(review): for widths narrower than `int`, C integer promotion converts
  // unsigned operands to `int` (e.g. uint16_t * uint16_t), so the unsigned
  // form alone may not rule out signed-overflow UB — confirm.
497  Type arithmeticType = type;
498  if ((type.isSignlessInteger() || type.isSignedInteger()) &&
499  !bitEnumContainsAll(op.getOverflowFlags(),
500  arith::IntegerOverflowFlags::nsw)) {
501  // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
502  // we compute in unsigned integers to avoid UB.
503  arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
504  /*isSigned=*/false);
505  }
506 
507  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
508  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
509 
510  Value arithmeticResult =
511  EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
512 
513  Value result = adaptValueType(arithmeticResult, rewriter, type);
514 
515  rewriter.replaceOp(op, result);
516  return success();
517  }
518 };
519 
/// Lowers bitwise and/or/xor on builtin integers to \p EmitCOp.
///
/// i1 operands are combined directly. Wider integers are processed as
/// unsigned values and cast back. Pointer-wide types are rejected by the
/// IntegerType check below.
520 template <typename ArithOp, typename EmitCOp>
521 class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
522 public:
524 
525  LogicalResult
526  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
527  ConversionPatternRewriter &rewriter) const override {
528 
529  Type type = this->getTypeConverter()->convertType(op.getType());
530  if (!isa_and_nonnull<IntegerType>(type)) {
531  return rewriter.notifyMatchFailure(
532  op,
533  "expected integer type, vector/tensor support not yet implemented");
534  }
535 
536  // Bitwise ops can be performed directly on booleans
537  if (type.isInteger(1)) {
538  rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
539  adaptor.getRhs());
540  return success();
541  }
542 
543  // Bitwise ops are defined by the C standard on unsigned operands.
544  Type arithmeticType =
545  adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);
546 
547  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
548  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
549 
550  Value arithmeticResult =
551  EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
552 
553  Value result = adaptValueType(arithmeticResult, rewriter, type);
554 
555  rewriter.replaceOp(op, result);
556  return success();
557  }
558 };
559 
/// Lowers integer shifts to \p EmitCOp with a guard on the shift amount.
///
/// \tparam isUnsignedOp Whether the shifted value is converted to an unsigned
///   type before shifting.
///
/// Shifting by at least the bit width is UB in C, and such a shift yields
/// poison in arith. The result is therefore built inside an emitc.expression
/// as `(rhs < width) ? (lhs <op> rhs) : 0`, with 0 as the concrete refinement
/// of poison. The shift amount is always treated as unsigned.
560 template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
561 class ShiftOpConversion : public OpConversionPattern<ArithOp> {
562 public:
564 
565  LogicalResult
566  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
567  ConversionPatternRewriter &rewriter) const override {
568 
569  Type type = this->getTypeConverter()->convertType(op.getType());
570  if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
571  return rewriter.notifyMatchFailure(
572  op, "expected integer or size_t/ssize_t/ptrdiff_t type");
573  }
574 
575  if (type.isInteger(1)) {
576  return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
577  }
578 
579  Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
580 
581  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
582  // Shift amount interpreted as unsigned per Arith dialect spec.
583  Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
584  /*needsUnsigned=*/true);
585  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
586 
587  // Add a runtime check for overflow
  // (i.e. compute the bit width of `type` so that out-of-range shift amounts
  // can be detected).
588  Value width;
589  if (emitc::isPointerWideType(type)) {
  // Width = 8 * sizeof(eight). sizeof is applied to `eight`, whose type is
  // rhsType, so this equals the width of `type` only if both have the same
  // size. NOTE(review): presumably holds because arith shifts use one type for
  // both operands — confirm.
590  Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType,
591  rewriter.getIndexAttr(8));
592  emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create(
593  rewriter, op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
594  width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight,
595  sizeOfCall.getResult(0));
596  } else {
597  width = emitc::ConstantOp::create(
598  rewriter, op.getLoc(), rhsType,
599  rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
600  }
601 
602  Value excessCheck =
603  emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
604  emitc::CmpPredicate::lt, rhs, width);
605 
606  // Any concrete value is a valid refinement of poison.
607  Value poison = emitc::ConstantOp::create(
608  rewriter, op.getLoc(), arithmeticType,
609  (isa<IntegerType>(arithmeticType)
610  ? rewriter.getIntegerAttr(arithmeticType, 0)
611  : rewriter.getIndexAttr(0)));
612 
  // Operands of the expression become block arguments 0..3 in this order:
  // lhs, rhs, excessCheck, poison.
613  emitc::ExpressionOp ternary =
614  emitc::ExpressionOp::create(rewriter, op.getLoc(), arithmeticType,
615  ValueRange({lhs, rhs, excessCheck, poison}),
616  /*do_not_inline=*/false);
617  Block &bodyBlock = ternary.createBody();
  // NOTE(review): only the iterator is saved. The restore below assumes the
  // insertion block is op's block; an OpBuilder::InsertionGuard would restore
  // both.
618  auto currentPoint = rewriter.getInsertionPoint();
619  rewriter.setInsertionPointToStart(&bodyBlock);
620  Value arithmeticResult =
621  EmitCOp::create(rewriter, op.getLoc(), arithmeticType,
622  bodyBlock.getArgument(0), bodyBlock.getArgument(1));
623  Value resultOrPoison = emitc::ConditionalOp::create(
624  rewriter, op.getLoc(), arithmeticType, bodyBlock.getArgument(2),
625  arithmeticResult, bodyBlock.getArgument(3));
626  emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
627  rewriter.setInsertionPoint(op->getBlock(), currentPoint);
628 
629  Value result = adaptValueType(ternary, rewriter, type);
630 
631  rewriter.replaceOp(op, result);
632  return success();
633  }
634 };
635 
636 template <typename ArithOp, typename EmitCOp>
637 class SignedShiftOpConversion final
638  : public ShiftOpConversion<ArithOp, EmitCOp, false> {
639  using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
640 };
641 
642 template <typename ArithOp, typename EmitCOp>
643 class UnsignedShiftOpConversion final
644  : public ShiftOpConversion<ArithOp, EmitCOp, true> {
645  using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
646 };
647 
648 class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
649 public:
650  using Base::Base;
651 
652  LogicalResult
653  matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
654  ConversionPatternRewriter &rewriter) const override {
655 
656  Type dstType = getTypeConverter()->convertType(selectOp.getType());
657  if (!dstType)
658  return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
659 
660  if (!adaptor.getCondition().getType().isInteger(1))
661  return rewriter.notifyMatchFailure(
662  selectOp,
663  "can only be converted if condition is a scalar of type i1");
664 
665  rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
666  adaptor.getOperands());
667 
668  return success();
669  }
670 };
671 
672 // Floating-point to integer conversions.
673 template <typename CastOp>
674 class FtoICastOpConversion : public OpConversionPattern<CastOp> {
675 public:
676  FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
677  : OpConversionPattern<CastOp>(typeConverter, context) {}
678 
679  LogicalResult
680  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
681  ConversionPatternRewriter &rewriter) const override {
682 
683  Type operandType = adaptor.getIn().getType();
684  if (!emitc::isSupportedFloatType(operandType))
685  return rewriter.notifyMatchFailure(castOp,
686  "unsupported cast source type");
687 
688  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
689  if (!dstType)
690  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
691 
692  // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
693  // truncated to 0, whereas a boolean conversion would return true.
694  if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
695  return rewriter.notifyMatchFailure(castOp,
696  "unsupported cast destination type");
697 
698  // Convert to unsigned if it's the "ui" variant
699  // Signless is interpreted as signed, so no need to cast for "si"
700  Type actualResultType = dstType;
701  if (isa<arith::FPToUIOp>(castOp)) {
702  actualResultType =
703  rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
704  /*isSigned=*/false);
705  }
706 
707  Value result = emitc::CastOp::create(
708  rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands());
709 
710  if (isa<arith::FPToUIOp>(castOp)) {
711  result =
712  emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result);
713  }
714  rewriter.replaceOp(castOp, result);
715 
716  return success();
717  }
718 };
719 
720 // Integer to floating-point conversions.
721 template <typename CastOp>
722 class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
723 public:
724  ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
725  : OpConversionPattern<CastOp>(typeConverter, context) {}
726 
727  LogicalResult
728  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
729  ConversionPatternRewriter &rewriter) const override {
730  // Vectors in particular are not supported
731  Type operandType = adaptor.getIn().getType();
732  if (!emitc::isSupportedIntegerType(operandType))
733  return rewriter.notifyMatchFailure(castOp,
734  "unsupported cast source type");
735 
736  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
737  if (!dstType)
738  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
739 
740  if (!emitc::isSupportedFloatType(dstType))
741  return rewriter.notifyMatchFailure(castOp,
742  "unsupported cast destination type");
743 
744  // Convert to unsigned if it's the "ui" variant
745  // Signless is interpreted as signed, so no need to cast for "si"
746  Type actualOperandType = operandType;
747  if (isa<arith::UIToFPOp>(castOp)) {
748  actualOperandType =
749  rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
750  /*isSigned=*/false);
751  }
752  Value fpCastOperand = adaptor.getIn();
753  if (actualOperandType != operandType) {
754  fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(),
755  actualOperandType, fpCastOperand);
756  }
757  rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
758 
759  return success();
760  }
761 };
762 
763 // Floating-point to floating-point conversions.
764 template <typename CastOp>
765 class FpCastOpConversion : public OpConversionPattern<CastOp> {
766 public:
767  FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
768  : OpConversionPattern<CastOp>(typeConverter, context) {}
769 
770  LogicalResult
771  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
772  ConversionPatternRewriter &rewriter) const override {
773  // Vectors in particular are not supported.
774  Type operandType = adaptor.getIn().getType();
775  if (!emitc::isSupportedFloatType(operandType))
776  return rewriter.notifyMatchFailure(castOp,
777  "unsupported cast source type");
778  if (auto roundingModeOp =
779  dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
780  // Only supporting default rounding mode as of now.
781  if (roundingModeOp.getRoundingModeAttr())
782  return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
783  }
784 
785  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
786  if (!dstType)
787  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
788 
789  if (!emitc::isSupportedFloatType(dstType))
790  return rewriter.notifyMatchFailure(castOp,
791  "unsupported cast destination type");
792 
793  Value fpCastOperand = adaptor.getIn();
794  rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
795 
796  return success();
797  }
798 };
799 
800 } // namespace
801 
802 //===----------------------------------------------------------------------===//
803 // Pattern population
804 //===----------------------------------------------------------------------===//
805 
  // All patterns below share the caller's type converter and the pattern
  // set's context. Each entry pairs an arith op with the EmitC op it lowers
  // to; the wrapper class decides how operand signedness and overflow are
  // handled.
808  MLIRContext *ctx = patterns.getContext();
809 
811 
812  // clang-format off
813  patterns.add<
814  ArithConstantOpConversionPattern,
815  ArithOpConversion<arith::AddFOp, emitc::AddOp>,
816  ArithOpConversion<arith::DivFOp, emitc::DivOp>,
817  ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
818  ArithOpConversion<arith::MulFOp, emitc::MulOp>,
819  ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
820  ArithOpConversion<arith::SubFOp, emitc::SubOp>,
821  BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
822  BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
823  IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
824  IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
825  IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
826  BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
827  BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
828  BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
829  UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
830  SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
831  UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
832  CmpFOpConversion,
833  CmpIOpConversion,
834  NegFOpConversion,
835  SelectOpConversion,
836  // Truncation is guaranteed for unsigned types (conversion to an unsigned C type is modular), so trunci takes the unsigned cast path.
837  UnsignedCastConversion<arith::TruncIOp>,
838  SignedCastConversion<arith::ExtSIOp>,
839  UnsignedCastConversion<arith::ExtUIOp>,
840  SignedCastConversion<arith::IndexCastOp>,
841  UnsignedCastConversion<arith::IndexCastUIOp>,
842  ItoFCastOpConversion<arith::SIToFPOp>,
843  ItoFCastOpConversion<arith::UIToFPOp>,
844  FtoICastOpConversion<arith::FPToSIOp>,
845  FtoICastOpConversion<arith::FPToUIOp>,
846  FpCastOpConversion<arith::ExtFOp>,
847  FpCastOpConversion<arith::TruncFOp>
848  >(typeConverter, ctx);
849  // clang-format on
850 }
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:227
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:99
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
TypedAttr getOneAttr(Type type)
Definition: Builders.cpp:341
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:445
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:431
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:525
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:64
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition: EmitC.cpp:116
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t/ssize_t/ptrdiff_t type.
Definition: EmitC.cpp:131
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition: EmitC.cpp:95
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
void registerConvertArithToEmitCInterface(DialectRegistry &registry)
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)