MLIR  22.0.0git
ArithToEmitC.cpp
Go to the documentation of this file.
1 //===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert the Arith dialect to the EmitC
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
21 #include "mlir/IR/BuiltinTypes.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 /// Implement the interface to convert Arith to EmitC.
28 struct ArithToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
30 
31  /// Hook for derived dialect interface to provide conversion patterns
32  /// and mark dialect legal for the conversion target.
33  void populateConvertToEmitCConversionPatterns(
34  ConversionTarget &target, TypeConverter &typeConverter,
35  RewritePatternSet &patterns) const final {
36  populateArithToEmitCPatterns(typeConverter, patterns);
37  }
38 };
39 } // namespace
40 
42  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
43  dialect->addInterfaces<ArithToEmitCDialectInterface>();
44  });
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Conversion Patterns
49 //===----------------------------------------------------------------------===//
50 
51 namespace {
52 class ArithConstantOpConversionPattern
53  : public OpConversionPattern<arith::ConstantOp> {
54 public:
56 
57  LogicalResult
58  matchAndRewrite(arith::ConstantOp arithConst,
59  arith::ConstantOp::Adaptor adaptor,
60  ConversionPatternRewriter &rewriter) const override {
61  Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
62  if (!newTy)
63  return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
64  rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
65  adaptor.getValue());
66  return success();
67  }
68 };
69 
70 /// Get the signed or unsigned type corresponding to \p ty.
71 Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
72  if (isa<IntegerType>(ty)) {
73  if (ty.isUnsignedInteger() != needsUnsigned) {
74  auto signedness = needsUnsigned
75  ? IntegerType::SignednessSemantics::Unsigned
78  signedness);
79  }
80  } else if (emitc::isPointerWideType(ty)) {
81  if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
82  if (needsUnsigned)
83  return emitc::SizeTType::get(ty.getContext());
85  }
86  }
87  return ty;
88 }
89 
90 /// Insert a cast operation to type \p ty if \p val does not have this type.
91 Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
92  return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
93 }
94 
95 class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
96 public:
98 
99  LogicalResult
100  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
101  ConversionPatternRewriter &rewriter) const override {
102 
103  if (!isa<FloatType>(adaptor.getRhs().getType())) {
104  return rewriter.notifyMatchFailure(op.getLoc(),
105  "cmpf currently only supported on "
106  "floats, not tensors/vectors thereof");
107  }
108 
109  bool unordered = false;
110  emitc::CmpPredicate predicate;
111  switch (op.getPredicate()) {
112  case arith::CmpFPredicate::AlwaysFalse: {
113  auto constant =
114  emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
115  rewriter.getBoolAttr(/*value=*/false));
116  rewriter.replaceOp(op, constant);
117  return success();
118  }
119  case arith::CmpFPredicate::OEQ:
120  unordered = false;
121  predicate = emitc::CmpPredicate::eq;
122  break;
123  case arith::CmpFPredicate::OGT:
124  unordered = false;
125  predicate = emitc::CmpPredicate::gt;
126  break;
127  case arith::CmpFPredicate::OGE:
128  unordered = false;
129  predicate = emitc::CmpPredicate::ge;
130  break;
131  case arith::CmpFPredicate::OLT:
132  unordered = false;
133  predicate = emitc::CmpPredicate::lt;
134  break;
135  case arith::CmpFPredicate::OLE:
136  unordered = false;
137  predicate = emitc::CmpPredicate::le;
138  break;
139  case arith::CmpFPredicate::ONE:
140  unordered = false;
141  predicate = emitc::CmpPredicate::ne;
142  break;
143  case arith::CmpFPredicate::ORD: {
144  // ordered, i.e. none of the operands is NaN
145  auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
146  adaptor.getRhs());
147  rewriter.replaceOp(op, cmp);
148  return success();
149  }
150  case arith::CmpFPredicate::UEQ:
151  unordered = true;
152  predicate = emitc::CmpPredicate::eq;
153  break;
154  case arith::CmpFPredicate::UGT:
155  unordered = true;
156  predicate = emitc::CmpPredicate::gt;
157  break;
158  case arith::CmpFPredicate::UGE:
159  unordered = true;
160  predicate = emitc::CmpPredicate::ge;
161  break;
162  case arith::CmpFPredicate::ULT:
163  unordered = true;
164  predicate = emitc::CmpPredicate::lt;
165  break;
166  case arith::CmpFPredicate::ULE:
167  unordered = true;
168  predicate = emitc::CmpPredicate::le;
169  break;
170  case arith::CmpFPredicate::UNE:
171  unordered = true;
172  predicate = emitc::CmpPredicate::ne;
173  break;
174  case arith::CmpFPredicate::UNO: {
175  // unordered, i.e. either operand is nan
176  auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
177  adaptor.getRhs());
178  rewriter.replaceOp(op, cmp);
179  return success();
180  }
181  case arith::CmpFPredicate::AlwaysTrue: {
182  auto constant =
183  emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
184  rewriter.getBoolAttr(/*value=*/true));
185  rewriter.replaceOp(op, constant);
186  return success();
187  }
188  }
189 
190  // Compare the values naively
191  auto cmpResult =
192  emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate,
193  adaptor.getLhs(), adaptor.getRhs());
194 
195  // Adjust the results for unordered/ordered semantics
196  if (unordered) {
197  auto isUnordered = createCheckIsUnordered(
198  rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
199  rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
200  isUnordered, cmpResult);
201  return success();
202  }
203 
204  auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
205  adaptor.getLhs(), adaptor.getRhs());
206  rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
207  isOrdered, cmpResult);
208  return success();
209  }
210 
211 private:
212  /// Return a value that is true if \p operand is NaN.
213  Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
214  Value operand) const {
215  // A value is NaN exactly when it compares unequal to itself.
216  return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
217  emitc::CmpPredicate::ne, operand, operand);
218  }
219 
220  /// Return a value that is true if \p operand is not NaN.
221  Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
222  Value operand) const {
223  // A value is not NaN exactly when it compares equal to itself.
224  return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
225  emitc::CmpPredicate::eq, operand, operand);
226  }
227 
228  /// Return a value that is true if the operands \p first and \p second are
229  /// unordered (i.e., at least one of them is NaN).
230  Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
231  Location loc, Value first, Value second) const {
232  auto firstIsNaN = isNaN(rewriter, loc, first);
233  auto secondIsNaN = isNaN(rewriter, loc, second);
234  return emitc::LogicalOrOp::create(rewriter, loc, rewriter.getI1Type(),
235  firstIsNaN, secondIsNaN);
236  }
237 
238  /// Return a value that is true if the operands \p first and \p second are
239  /// both ordered (i.e., none one of them is NaN).
240  Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
241  Value first, Value second) const {
242  auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
243  auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
244  return emitc::LogicalAndOp::create(rewriter, loc, rewriter.getI1Type(),
245  firstIsNotNaN, secondIsNotNaN);
246  }
247 };
248 
249 class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
250 public:
252 
253  bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
254  switch (pred) {
255  case arith::CmpIPredicate::eq:
256  case arith::CmpIPredicate::ne:
257  case arith::CmpIPredicate::slt:
258  case arith::CmpIPredicate::sle:
259  case arith::CmpIPredicate::sgt:
260  case arith::CmpIPredicate::sge:
261  return false;
262  case arith::CmpIPredicate::ult:
263  case arith::CmpIPredicate::ule:
264  case arith::CmpIPredicate::ugt:
265  case arith::CmpIPredicate::uge:
266  return true;
267  }
268  llvm_unreachable("unknown cmpi predicate kind");
269  }
270 
271  emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
272  switch (pred) {
273  case arith::CmpIPredicate::eq:
274  return emitc::CmpPredicate::eq;
275  case arith::CmpIPredicate::ne:
276  return emitc::CmpPredicate::ne;
277  case arith::CmpIPredicate::slt:
278  case arith::CmpIPredicate::ult:
279  return emitc::CmpPredicate::lt;
280  case arith::CmpIPredicate::sle:
281  case arith::CmpIPredicate::ule:
282  return emitc::CmpPredicate::le;
283  case arith::CmpIPredicate::sgt:
284  case arith::CmpIPredicate::ugt:
285  return emitc::CmpPredicate::gt;
286  case arith::CmpIPredicate::sge:
287  case arith::CmpIPredicate::uge:
288  return emitc::CmpPredicate::ge;
289  }
290  llvm_unreachable("unknown cmpi predicate kind");
291  }
292 
293  LogicalResult
294  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
295  ConversionPatternRewriter &rewriter) const override {
296 
297  Type type = adaptor.getLhs().getType();
298  if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
299  return rewriter.notifyMatchFailure(
300  op, "expected integer or size_t/ssize_t/ptrdiff_t type");
301  }
302 
303  bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
304  emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
305 
306  Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
307  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
308  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
309 
310  rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
311  return success();
312  }
313 };
314 
315 class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
316 public:
318 
319  LogicalResult
320  matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
321  ConversionPatternRewriter &rewriter) const override {
322 
323  auto adaptedOp = adaptor.getOperand();
324  auto adaptedOpType = adaptedOp.getType();
325 
326  if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
327  return rewriter.notifyMatchFailure(
328  op.getLoc(),
329  "negf currently only supports scalar types, not vectors or tensors");
330  }
331 
332  if (!emitc::isSupportedFloatType(adaptedOpType)) {
333  return rewriter.notifyMatchFailure(
334  op.getLoc(), "floating-point type is not supported by EmitC");
335  }
336 
337  rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType,
338  adaptedOp);
339  return success();
340  }
341 };
342 
343 template <typename ArithOp, bool castToUnsigned>
344 class CastConversion : public OpConversionPattern<ArithOp> {
345 public:
347 
348  LogicalResult
349  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
350  ConversionPatternRewriter &rewriter) const override {
351 
352  Type opReturnType = this->getTypeConverter()->convertType(op.getType());
353  if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
354  emitc::isPointerWideType(opReturnType)))
355  return rewriter.notifyMatchFailure(
356  op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
357 
358  if (adaptor.getOperands().size() != 1) {
359  return rewriter.notifyMatchFailure(
360  op, "CastConversion only supports unary ops");
361  }
362 
363  Type operandType = adaptor.getIn().getType();
364  if (!operandType || !(isa<IntegerType>(operandType) ||
365  emitc::isPointerWideType(operandType)))
366  return rewriter.notifyMatchFailure(
367  op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
368 
369  // Signed (sign-extending) casts from i1 are not supported.
370  if (operandType.isInteger(1) && !castToUnsigned)
371  return rewriter.notifyMatchFailure(op,
372  "operation not supported on i1 type");
373 
374  // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
375  // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
376  // truncation.
377  if (opReturnType.isInteger(1)) {
378  Type attrType = (emitc::isPointerWideType(operandType))
379  ? rewriter.getIndexType()
380  : operandType;
381  auto constOne = emitc::ConstantOp::create(
382  rewriter, op.getLoc(), operandType, rewriter.getOneAttr(attrType));
383  auto oneAndOperand = emitc::BitwiseAndOp::create(
384  rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne);
385  rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
386  oneAndOperand);
387  return success();
388  }
389 
390  bool isTruncation =
391  (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
392  operandType.getIntOrFloatBitWidth() >
393  opReturnType.getIntOrFloatBitWidth());
394  bool doUnsigned = castToUnsigned || isTruncation;
395 
396  // Adapt the signedness of the result (bitwidth-preserving cast)
397  // This is needed e.g., if the return type is signless.
398  Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
399 
400  // Adapt the signedness of the operand (bitwidth-preserving cast)
401  Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
402  Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
403 
404  // Actual cast (may change bitwidth)
405  auto cast =
406  emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp);
407 
408  // Cast to the expected output type
409  auto result = adaptValueType(cast, rewriter, opReturnType);
410 
411  rewriter.replaceOp(op, result);
412  return success();
413  }
414 };
415 
416 template <typename ArithOp>
417 class UnsignedCastConversion : public CastConversion<ArithOp, true> {
418  using CastConversion<ArithOp, true>::CastConversion;
419 };
420 
421 template <typename ArithOp>
422 class SignedCastConversion : public CastConversion<ArithOp, false> {
423  using CastConversion<ArithOp, false>::CastConversion;
424 };
425 
426 template <typename ArithOp, typename EmitCOp>
427 class ArithOpConversion final : public OpConversionPattern<ArithOp> {
428 public:
430 
431  LogicalResult
432  matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
433  ConversionPatternRewriter &rewriter) const override {
434 
435  Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
436  if (!newTy)
437  return rewriter.notifyMatchFailure(arithOp,
438  "converting result type failed");
439  rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
440  adaptor.getOperands());
441 
442  return success();
443  }
444 };
445 
446 template <class ArithOp, class EmitCOp>
447 class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
448 public:
450 
451  LogicalResult
452  matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor,
453  ConversionPatternRewriter &rewriter) const override {
454  Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
455  if (!newRetTy)
456  return rewriter.notifyMatchFailure(uiBinOp,
457  "converting result type failed");
458  if (!isa<IntegerType>(newRetTy)) {
459  return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
460  }
461  Type unsignedType =
462  adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true);
463  if (!unsignedType)
464  return rewriter.notifyMatchFailure(uiBinOp,
465  "converting result type failed");
466  Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
467  Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
468 
469  auto newDivOp = EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType,
470  ArrayRef<Value>{lhsAdapted, rhsAdapted});
471  Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
472  rewriter.replaceOp(uiBinOp, resultAdapted);
473  return success();
474  }
475 };
476 
477 template <typename ArithOp, typename EmitCOp>
478 class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
479 public:
481 
482  LogicalResult
483  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
484  ConversionPatternRewriter &rewriter) const override {
485 
486  Type type = this->getTypeConverter()->convertType(op.getType());
487  if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
488  return rewriter.notifyMatchFailure(
489  op, "expected integer or size_t/ssize_t/ptrdiff_t type");
490  }
491 
492  if (type.isInteger(1)) {
493  // arith expects wrap-around arithmethic, which doesn't happen on `bool`.
494  return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
495  }
496 
497  Type arithmeticType = type;
498  if ((type.isSignlessInteger() || type.isSignedInteger()) &&
499  !bitEnumContainsAll(op.getOverflowFlags(),
500  arith::IntegerOverflowFlags::nsw)) {
501  // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
502  // we compute in unsigned integers to avoid UB.
503  arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
504  /*isSigned=*/false);
505  }
506 
507  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
508  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
509 
510  Value arithmeticResult =
511  EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
512 
513  Value result = adaptValueType(arithmeticResult, rewriter, type);
514 
515  rewriter.replaceOp(op, result);
516  return success();
517  }
518 };
519 
520 template <typename ArithOp, typename EmitCOp>
521 class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
522 public:
524 
525  LogicalResult
526  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
527  ConversionPatternRewriter &rewriter) const override {
528 
529  Type type = this->getTypeConverter()->convertType(op.getType());
530  if (!isa_and_nonnull<IntegerType>(type)) {
531  return rewriter.notifyMatchFailure(
532  op,
533  "expected integer type, vector/tensor support not yet implemented");
534  }
535 
536  // Bitwise ops can be performed directly on booleans
537  if (type.isInteger(1)) {
538  rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
539  adaptor.getRhs());
540  return success();
541  }
542 
543  // Bitwise ops are defined by the C standard on unsigned operands.
544  Type arithmeticType =
545  adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);
546 
547  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
548  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
549 
550  Value arithmeticResult =
551  EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
552 
553  Value result = adaptValueType(arithmeticResult, rewriter, type);
554 
555  rewriter.replaceOp(op, result);
556  return success();
557  }
558 };
559 
560 template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
561 class ShiftOpConversion : public OpConversionPattern<ArithOp> {
562 public:
564 
565  LogicalResult
566  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
567  ConversionPatternRewriter &rewriter) const override {
568 
569  Type type = this->getTypeConverter()->convertType(op.getType());
570  if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
571  return rewriter.notifyMatchFailure(
572  op, "expected integer or size_t/ssize_t/ptrdiff_t type");
573  }
574 
575  if (type.isInteger(1)) {
576  return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
577  }
578 
579  Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
580 
581  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
582  // Shift amount interpreted as unsigned per Arith dialect spec.
583  Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
584  /*needsUnsigned=*/true);
585  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
586 
587  // Add a runtime check for overflow
588  Value width;
589  if (emitc::isPointerWideType(type)) {
590  Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType,
591  rewriter.getIndexAttr(8));
592  emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create(
593  rewriter, op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
594  width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight,
595  sizeOfCall.getResult(0));
596  } else {
597  width = emitc::ConstantOp::create(
598  rewriter, op.getLoc(), rhsType,
599  rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
600  }
601 
602  Value excessCheck =
603  emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
604  emitc::CmpPredicate::lt, rhs, width);
605 
606  // Any concrete value is a valid refinement of poison.
607  Value poison = emitc::ConstantOp::create(
608  rewriter, op.getLoc(), arithmeticType,
609  (isa<IntegerType>(arithmeticType)
610  ? rewriter.getIntegerAttr(arithmeticType, 0)
611  : rewriter.getIndexAttr(0)));
612 
613  emitc::ExpressionOp ternary = emitc::ExpressionOp::create(
614  rewriter, op.getLoc(), arithmeticType, /*do_not_inline=*/false);
615  Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
616  auto currentPoint = rewriter.getInsertionPoint();
617  rewriter.setInsertionPointToStart(&bodyBlock);
618  Value arithmeticResult =
619  EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
620  Value resultOrPoison =
621  emitc::ConditionalOp::create(rewriter, op.getLoc(), arithmeticType,
622  excessCheck, arithmeticResult, poison);
623  emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
624  rewriter.setInsertionPoint(op->getBlock(), currentPoint);
625 
626  Value result = adaptValueType(ternary, rewriter, type);
627 
628  rewriter.replaceOp(op, result);
629  return success();
630  }
631 };
632 
633 template <typename ArithOp, typename EmitCOp>
634 class SignedShiftOpConversion final
635  : public ShiftOpConversion<ArithOp, EmitCOp, false> {
636  using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
637 };
638 
639 template <typename ArithOp, typename EmitCOp>
640 class UnsignedShiftOpConversion final
641  : public ShiftOpConversion<ArithOp, EmitCOp, true> {
642  using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
643 };
644 
645 class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
646 public:
648 
649  LogicalResult
650  matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
651  ConversionPatternRewriter &rewriter) const override {
652 
653  Type dstType = getTypeConverter()->convertType(selectOp.getType());
654  if (!dstType)
655  return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
656 
657  if (!adaptor.getCondition().getType().isInteger(1))
658  return rewriter.notifyMatchFailure(
659  selectOp,
660  "can only be converted if condition is a scalar of type i1");
661 
662  rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
663  adaptor.getOperands());
664 
665  return success();
666  }
667 };
668 
669 // Floating-point to integer conversions.
670 template <typename CastOp>
671 class FtoICastOpConversion : public OpConversionPattern<CastOp> {
672 public:
673  FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
674  : OpConversionPattern<CastOp>(typeConverter, context) {}
675 
676  LogicalResult
677  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
678  ConversionPatternRewriter &rewriter) const override {
679 
680  Type operandType = adaptor.getIn().getType();
681  if (!emitc::isSupportedFloatType(operandType))
682  return rewriter.notifyMatchFailure(castOp,
683  "unsupported cast source type");
684 
685  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
686  if (!dstType)
687  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
688 
689  // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
690  // truncated to 0, whereas a boolean conversion would return true.
691  if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
692  return rewriter.notifyMatchFailure(castOp,
693  "unsupported cast destination type");
694 
695  // Convert to unsigned if it's the "ui" variant
696  // Signless is interpreted as signed, so no need to cast for "si"
697  Type actualResultType = dstType;
698  if (isa<arith::FPToUIOp>(castOp)) {
699  actualResultType =
700  rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
701  /*isSigned=*/false);
702  }
703 
704  Value result = emitc::CastOp::create(
705  rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands());
706 
707  if (isa<arith::FPToUIOp>(castOp)) {
708  result =
709  emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result);
710  }
711  rewriter.replaceOp(castOp, result);
712 
713  return success();
714  }
715 };
716 
717 // Integer to floating-point conversions.
718 template <typename CastOp>
719 class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
720 public:
721  ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
722  : OpConversionPattern<CastOp>(typeConverter, context) {}
723 
724  LogicalResult
725  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
726  ConversionPatternRewriter &rewriter) const override {
727  // Vectors in particular are not supported
728  Type operandType = adaptor.getIn().getType();
729  if (!emitc::isSupportedIntegerType(operandType))
730  return rewriter.notifyMatchFailure(castOp,
731  "unsupported cast source type");
732 
733  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
734  if (!dstType)
735  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
736 
737  if (!emitc::isSupportedFloatType(dstType))
738  return rewriter.notifyMatchFailure(castOp,
739  "unsupported cast destination type");
740 
741  // Convert to unsigned if it's the "ui" variant
742  // Signless is interpreted as signed, so no need to cast for "si"
743  Type actualOperandType = operandType;
744  if (isa<arith::UIToFPOp>(castOp)) {
745  actualOperandType =
746  rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
747  /*isSigned=*/false);
748  }
749  Value fpCastOperand = adaptor.getIn();
750  if (actualOperandType != operandType) {
751  fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(),
752  actualOperandType, fpCastOperand);
753  }
754  rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
755 
756  return success();
757  }
758 };
759 
760 // Floating-point to floating-point conversions.
761 template <typename CastOp>
762 class FpCastOpConversion : public OpConversionPattern<CastOp> {
763 public:
764  FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
765  : OpConversionPattern<CastOp>(typeConverter, context) {}
766 
767  LogicalResult
768  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
769  ConversionPatternRewriter &rewriter) const override {
770  // Vectors in particular are not supported.
771  Type operandType = adaptor.getIn().getType();
772  if (!emitc::isSupportedFloatType(operandType))
773  return rewriter.notifyMatchFailure(castOp,
774  "unsupported cast source type");
775  if (auto roundingModeOp =
776  dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
777  // Only supporting default rounding mode as of now.
778  if (roundingModeOp.getRoundingModeAttr())
779  return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
780  }
781 
782  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
783  if (!dstType)
784  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
785 
786  if (!emitc::isSupportedFloatType(dstType))
787  return rewriter.notifyMatchFailure(castOp,
788  "unsupported cast destination type");
789 
790  Value fpCastOperand = adaptor.getIn();
791  rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
792 
793  return success();
794  }
795 };
796 
797 } // namespace
798 
799 //===----------------------------------------------------------------------===//
800 // Pattern population
801 //===----------------------------------------------------------------------===//
802 
805  MLIRContext *ctx = patterns.getContext();
806 
808 
809  // clang-format off
810  patterns.add<
811  ArithConstantOpConversionPattern,
812  ArithOpConversion<arith::AddFOp, emitc::AddOp>,
813  ArithOpConversion<arith::DivFOp, emitc::DivOp>,
814  ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
815  ArithOpConversion<arith::MulFOp, emitc::MulOp>,
816  ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
817  ArithOpConversion<arith::SubFOp, emitc::SubOp>,
818  BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
819  BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
820  IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
821  IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
822  IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
823  BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
824  BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
825  BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
826  UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
827  SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
828  UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
829  CmpFOpConversion,
830  CmpIOpConversion,
831  NegFOpConversion,
832  SelectOpConversion,
833  // Truncation is guaranteed for unsigned types.
834  UnsignedCastConversion<arith::TruncIOp>,
835  SignedCastConversion<arith::ExtSIOp>,
836  UnsignedCastConversion<arith::ExtUIOp>,
837  SignedCastConversion<arith::IndexCastOp>,
838  UnsignedCastConversion<arith::IndexCastUIOp>,
839  ItoFCastOpConversion<arith::SIToFPOp>,
840  ItoFCastOpConversion<arith::UIToFPOp>,
841  FtoICastOpConversion<arith::FPToSIOp>,
842  FtoICastOpConversion<arith::FPToUIOp>,
843  FpCastOpConversion<arith::ExtFOp>,
844  FpCastOpConversion<arith::TruncFOp>
845  >(typeConverter, ctx);
846  // clang-format on
847 }
Block represents an ordered list of Operations.
Definition: Block.h:33
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:103
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:223
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:95
IntegerType getI1Type()
Definition: Builders.cpp:52
IndexType getIndexType()
Definition: Builders.cpp:50
TypedAttr getOneAttr(Type type)
Definition: Builders.cpp:337
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:517
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:702
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:519
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:64
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:24
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition: EmitC.cpp:114
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t/ssize_t type.
Definition: EmitC.cpp:132
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition: EmitC.cpp:93
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
void registerConvertArithToEmitCInterface(DialectRegistry &registry)
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)