MLIR  21.0.0git
ArithToEmitC.cpp
Go to the documentation of this file.
1 //===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert the Arith dialect to the EmitC
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
21 #include "mlir/IR/BuiltinTypes.h"
23 
24 using namespace mlir;
25 
26 namespace {
27 /// Implement the interface to convert Arith to EmitC.
28 struct ArithToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
30 
31  /// Hook for derived dialect interface to provide conversion patterns
32  /// and mark dialect legal for the conversion target.
33  void populateConvertToEmitCConversionPatterns(
34  ConversionTarget &target, TypeConverter &typeConverter,
35  RewritePatternSet &patterns) const final {
36  populateArithToEmitCPatterns(typeConverter, patterns);
37  }
38 };
39 } // namespace
40 
42  registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
43  dialect->addInterfaces<ArithToEmitCDialectInterface>();
44  });
45 }
46 
47 //===----------------------------------------------------------------------===//
48 // Conversion Patterns
49 //===----------------------------------------------------------------------===//
50 
51 namespace {
52 class ArithConstantOpConversionPattern
53  : public OpConversionPattern<arith::ConstantOp> {
54 public:
56 
57  LogicalResult
58  matchAndRewrite(arith::ConstantOp arithConst,
59  arith::ConstantOp::Adaptor adaptor,
60  ConversionPatternRewriter &rewriter) const override {
61  Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
62  if (!newTy)
63  return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
64  rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
65  adaptor.getValue());
66  return success();
67  }
68 };
69 
70 /// Get the signed or unsigned type corresponding to \p ty.
71 Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
72  if (isa<IntegerType>(ty)) {
73  if (ty.isUnsignedInteger() != needsUnsigned) {
74  auto signedness = needsUnsigned
75  ? IntegerType::SignednessSemantics::Unsigned
78  signedness);
79  }
80  } else if (emitc::isPointerWideType(ty)) {
81  if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
82  if (needsUnsigned)
83  return emitc::SizeTType::get(ty.getContext());
85  }
86  }
87  return ty;
88 }
89 
90 /// Insert a cast operation to type \p ty if \p val does not have this type.
91 Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
92  return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
93 }
94 
95 class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
96 public:
98 
99  LogicalResult
100  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
101  ConversionPatternRewriter &rewriter) const override {
102 
103  if (!isa<FloatType>(adaptor.getRhs().getType())) {
104  return rewriter.notifyMatchFailure(op.getLoc(),
105  "cmpf currently only supported on "
106  "floats, not tensors/vectors thereof");
107  }
108 
109  bool unordered = false;
110  emitc::CmpPredicate predicate;
111  switch (op.getPredicate()) {
112  case arith::CmpFPredicate::AlwaysFalse: {
113  auto constant = rewriter.create<emitc::ConstantOp>(
114  op.getLoc(), rewriter.getI1Type(),
115  rewriter.getBoolAttr(/*value=*/false));
116  rewriter.replaceOp(op, constant);
117  return success();
118  }
119  case arith::CmpFPredicate::OEQ:
120  unordered = false;
121  predicate = emitc::CmpPredicate::eq;
122  break;
123  case arith::CmpFPredicate::OGT:
124  unordered = false;
125  predicate = emitc::CmpPredicate::gt;
126  break;
127  case arith::CmpFPredicate::OGE:
128  unordered = false;
129  predicate = emitc::CmpPredicate::ge;
130  break;
131  case arith::CmpFPredicate::OLT:
132  unordered = false;
133  predicate = emitc::CmpPredicate::lt;
134  break;
135  case arith::CmpFPredicate::OLE:
136  unordered = false;
137  predicate = emitc::CmpPredicate::le;
138  break;
139  case arith::CmpFPredicate::ONE:
140  unordered = false;
141  predicate = emitc::CmpPredicate::ne;
142  break;
143  case arith::CmpFPredicate::ORD: {
144  // ordered, i.e. none of the operands is NaN
145  auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
146  adaptor.getRhs());
147  rewriter.replaceOp(op, cmp);
148  return success();
149  }
150  case arith::CmpFPredicate::UEQ:
151  unordered = true;
152  predicate = emitc::CmpPredicate::eq;
153  break;
154  case arith::CmpFPredicate::UGT:
155  unordered = true;
156  predicate = emitc::CmpPredicate::gt;
157  break;
158  case arith::CmpFPredicate::UGE:
159  unordered = true;
160  predicate = emitc::CmpPredicate::ge;
161  break;
162  case arith::CmpFPredicate::ULT:
163  unordered = true;
164  predicate = emitc::CmpPredicate::lt;
165  break;
166  case arith::CmpFPredicate::ULE:
167  unordered = true;
168  predicate = emitc::CmpPredicate::le;
169  break;
170  case arith::CmpFPredicate::UNE:
171  unordered = true;
172  predicate = emitc::CmpPredicate::ne;
173  break;
174  case arith::CmpFPredicate::UNO: {
175  // unordered, i.e. either operand is nan
176  auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
177  adaptor.getRhs());
178  rewriter.replaceOp(op, cmp);
179  return success();
180  }
181  case arith::CmpFPredicate::AlwaysTrue: {
182  auto constant = rewriter.create<emitc::ConstantOp>(
183  op.getLoc(), rewriter.getI1Type(),
184  rewriter.getBoolAttr(/*value=*/true));
185  rewriter.replaceOp(op, constant);
186  return success();
187  }
188  }
189 
190  // Compare the values naively
191  auto cmpResult =
192  rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
193  adaptor.getLhs(), adaptor.getRhs());
194 
195  // Adjust the results for unordered/ordered semantics
196  if (unordered) {
197  auto isUnordered = createCheckIsUnordered(
198  rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
199  rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
200  isUnordered, cmpResult);
201  return success();
202  }
203 
204  auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
205  adaptor.getLhs(), adaptor.getRhs());
206  rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
207  isOrdered, cmpResult);
208  return success();
209  }
210 
211 private:
212  /// Return a value that is true if \p operand is NaN.
213  Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
214  Value operand) const {
215  // A value is NaN exactly when it compares unequal to itself.
216  return rewriter.create<emitc::CmpOp>(
217  loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand);
218  }
219 
220  /// Return a value that is true if \p operand is not NaN.
221  Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
222  Value operand) const {
223  // A value is not NaN exactly when it compares equal to itself.
224  return rewriter.create<emitc::CmpOp>(
225  loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand);
226  }
227 
228  /// Return a value that is true if the operands \p first and \p second are
229  /// unordered (i.e., at least one of them is NaN).
230  Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
231  Location loc, Value first, Value second) const {
232  auto firstIsNaN = isNaN(rewriter, loc, first);
233  auto secondIsNaN = isNaN(rewriter, loc, second);
234  return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
235  firstIsNaN, secondIsNaN);
236  }
237 
238  /// Return a value that is true if the operands \p first and \p second are
239  /// both ordered (i.e., none one of them is NaN).
240  Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
241  Value first, Value second) const {
242  auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
243  auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
244  return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
245  firstIsNotNaN, secondIsNotNaN);
246  }
247 };
248 
249 class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
250 public:
252 
253  bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
254  switch (pred) {
255  case arith::CmpIPredicate::eq:
256  case arith::CmpIPredicate::ne:
257  case arith::CmpIPredicate::slt:
258  case arith::CmpIPredicate::sle:
259  case arith::CmpIPredicate::sgt:
260  case arith::CmpIPredicate::sge:
261  return false;
262  case arith::CmpIPredicate::ult:
263  case arith::CmpIPredicate::ule:
264  case arith::CmpIPredicate::ugt:
265  case arith::CmpIPredicate::uge:
266  return true;
267  }
268  llvm_unreachable("unknown cmpi predicate kind");
269  }
270 
271  emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
272  switch (pred) {
273  case arith::CmpIPredicate::eq:
274  return emitc::CmpPredicate::eq;
275  case arith::CmpIPredicate::ne:
276  return emitc::CmpPredicate::ne;
277  case arith::CmpIPredicate::slt:
278  case arith::CmpIPredicate::ult:
279  return emitc::CmpPredicate::lt;
280  case arith::CmpIPredicate::sle:
281  case arith::CmpIPredicate::ule:
282  return emitc::CmpPredicate::le;
283  case arith::CmpIPredicate::sgt:
284  case arith::CmpIPredicate::ugt:
285  return emitc::CmpPredicate::gt;
286  case arith::CmpIPredicate::sge:
287  case arith::CmpIPredicate::uge:
288  return emitc::CmpPredicate::ge;
289  }
290  llvm_unreachable("unknown cmpi predicate kind");
291  }
292 
293  LogicalResult
294  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
295  ConversionPatternRewriter &rewriter) const override {
296 
297  Type type = adaptor.getLhs().getType();
298  if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
299  return rewriter.notifyMatchFailure(
300  op, "expected integer or size_t/ssize_t/ptrdiff_t type");
301  }
302 
303  bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
304  emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
305 
306  Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
307  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
308  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
309 
310  rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
311  return success();
312  }
313 };
314 
315 class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
316 public:
318 
319  LogicalResult
320  matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
321  ConversionPatternRewriter &rewriter) const override {
322 
323  auto adaptedOp = adaptor.getOperand();
324  auto adaptedOpType = adaptedOp.getType();
325 
326  if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
327  return rewriter.notifyMatchFailure(
328  op.getLoc(),
329  "negf currently only supports scalar types, not vectors or tensors");
330  }
331 
332  if (!emitc::isSupportedFloatType(adaptedOpType)) {
333  return rewriter.notifyMatchFailure(
334  op.getLoc(), "floating-point type is not supported by EmitC");
335  }
336 
337  rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType,
338  adaptedOp);
339  return success();
340  }
341 };
342 
343 template <typename ArithOp, bool castToUnsigned>
344 class CastConversion : public OpConversionPattern<ArithOp> {
345 public:
347 
348  LogicalResult
349  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
350  ConversionPatternRewriter &rewriter) const override {
351 
352  Type opReturnType = this->getTypeConverter()->convertType(op.getType());
353  if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
354  emitc::isPointerWideType(opReturnType)))
355  return rewriter.notifyMatchFailure(
356  op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
357 
358  if (adaptor.getOperands().size() != 1) {
359  return rewriter.notifyMatchFailure(
360  op, "CastConversion only supports unary ops");
361  }
362 
363  Type operandType = adaptor.getIn().getType();
364  if (!operandType || !(isa<IntegerType>(operandType) ||
365  emitc::isPointerWideType(operandType)))
366  return rewriter.notifyMatchFailure(
367  op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
368 
369  // Signed (sign-extending) casts from i1 are not supported.
370  if (operandType.isInteger(1) && !castToUnsigned)
371  return rewriter.notifyMatchFailure(op,
372  "operation not supported on i1 type");
373 
374  // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
375  // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
376  // truncation.
377  if (opReturnType.isInteger(1)) {
378  Type attrType = (emitc::isPointerWideType(operandType))
379  ? rewriter.getIndexType()
380  : operandType;
381  auto constOne = rewriter.create<emitc::ConstantOp>(
382  op.getLoc(), operandType, rewriter.getOneAttr(attrType));
383  auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
384  op.getLoc(), operandType, adaptor.getIn(), constOne);
385  rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
386  oneAndOperand);
387  return success();
388  }
389 
390  bool isTruncation =
391  (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
392  operandType.getIntOrFloatBitWidth() >
393  opReturnType.getIntOrFloatBitWidth());
394  bool doUnsigned = castToUnsigned || isTruncation;
395 
396  // Adapt the signedness of the result (bitwidth-preserving cast)
397  // This is needed e.g., if the return type is signless.
398  Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
399 
400  // Adapt the signedness of the operand (bitwidth-preserving cast)
401  Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
402  Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
403 
404  // Actual cast (may change bitwidth)
405  auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
406  castDestType, actualOp);
407 
408  // Cast to the expected output type
409  auto result = adaptValueType(cast, rewriter, opReturnType);
410 
411  rewriter.replaceOp(op, result);
412  return success();
413  }
414 };
415 
416 template <typename ArithOp>
417 class UnsignedCastConversion : public CastConversion<ArithOp, true> {
418  using CastConversion<ArithOp, true>::CastConversion;
419 };
420 
421 template <typename ArithOp>
422 class SignedCastConversion : public CastConversion<ArithOp, false> {
423  using CastConversion<ArithOp, false>::CastConversion;
424 };
425 
426 template <typename ArithOp, typename EmitCOp>
427 class ArithOpConversion final : public OpConversionPattern<ArithOp> {
428 public:
430 
431  LogicalResult
432  matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
433  ConversionPatternRewriter &rewriter) const override {
434 
435  Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
436  if (!newTy)
437  return rewriter.notifyMatchFailure(arithOp,
438  "converting result type failed");
439  rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
440  adaptor.getOperands());
441 
442  return success();
443  }
444 };
445 
446 template <class ArithOp, class EmitCOp>
447 class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
448 public:
450 
451  LogicalResult
452  matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor,
453  ConversionPatternRewriter &rewriter) const override {
454  Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
455  if (!newRetTy)
456  return rewriter.notifyMatchFailure(uiBinOp,
457  "converting result type failed");
458  if (!isa<IntegerType>(newRetTy)) {
459  return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
460  }
461  Type unsignedType =
462  adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true);
463  if (!unsignedType)
464  return rewriter.notifyMatchFailure(uiBinOp,
465  "converting result type failed");
466  Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
467  Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
468 
469  auto newDivOp =
470  rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
471  ArrayRef<Value>{lhsAdapted, rhsAdapted});
472  Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
473  rewriter.replaceOp(uiBinOp, resultAdapted);
474  return success();
475  }
476 };
477 
478 template <typename ArithOp, typename EmitCOp>
479 class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
480 public:
482 
483  LogicalResult
484  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
485  ConversionPatternRewriter &rewriter) const override {
486 
487  Type type = this->getTypeConverter()->convertType(op.getType());
488  if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
489  return rewriter.notifyMatchFailure(
490  op, "expected integer or size_t/ssize_t/ptrdiff_t type");
491  }
492 
493  if (type.isInteger(1)) {
494  // arith expects wrap-around arithmethic, which doesn't happen on `bool`.
495  return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
496  }
497 
498  Type arithmeticType = type;
499  if ((type.isSignlessInteger() || type.isSignedInteger()) &&
500  !bitEnumContainsAll(op.getOverflowFlags(),
501  arith::IntegerOverflowFlags::nsw)) {
502  // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
503  // we compute in unsigned integers to avoid UB.
504  arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
505  /*isSigned=*/false);
506  }
507 
508  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
509  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
510 
511  Value arithmeticResult = rewriter.template create<EmitCOp>(
512  op.getLoc(), arithmeticType, lhs, rhs);
513 
514  Value result = adaptValueType(arithmeticResult, rewriter, type);
515 
516  rewriter.replaceOp(op, result);
517  return success();
518  }
519 };
520 
521 template <typename ArithOp, typename EmitCOp>
522 class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
523 public:
525 
526  LogicalResult
527  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
528  ConversionPatternRewriter &rewriter) const override {
529 
530  Type type = this->getTypeConverter()->convertType(op.getType());
531  if (!isa_and_nonnull<IntegerType>(type)) {
532  return rewriter.notifyMatchFailure(
533  op,
534  "expected integer type, vector/tensor support not yet implemented");
535  }
536 
537  // Bitwise ops can be performed directly on booleans
538  if (type.isInteger(1)) {
539  rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
540  adaptor.getRhs());
541  return success();
542  }
543 
544  // Bitwise ops are defined by the C standard on unsigned operands.
545  Type arithmeticType =
546  adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);
547 
548  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
549  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
550 
551  Value arithmeticResult = rewriter.template create<EmitCOp>(
552  op.getLoc(), arithmeticType, lhs, rhs);
553 
554  Value result = adaptValueType(arithmeticResult, rewriter, type);
555 
556  rewriter.replaceOp(op, result);
557  return success();
558  }
559 };
560 
561 template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
562 class ShiftOpConversion : public OpConversionPattern<ArithOp> {
563 public:
565 
566  LogicalResult
567  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
568  ConversionPatternRewriter &rewriter) const override {
569 
570  Type type = this->getTypeConverter()->convertType(op.getType());
571  if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
572  return rewriter.notifyMatchFailure(
573  op, "expected integer or size_t/ssize_t/ptrdiff_t type");
574  }
575 
576  if (type.isInteger(1)) {
577  return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
578  }
579 
580  Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
581 
582  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
583  // Shift amount interpreted as unsigned per Arith dialect spec.
584  Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
585  /*needsUnsigned=*/true);
586  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
587 
588  // Add a runtime check for overflow
589  Value width;
590  if (emitc::isPointerWideType(type)) {
591  Value eight = rewriter.create<emitc::ConstantOp>(
592  op.getLoc(), rhsType, rewriter.getIndexAttr(8));
593  emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
594  op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
595  width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
596  sizeOfCall.getResult(0));
597  } else {
598  width = rewriter.create<emitc::ConstantOp>(
599  op.getLoc(), rhsType,
600  rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
601  }
602 
603  Value excessCheck = rewriter.create<emitc::CmpOp>(
604  op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
605 
606  // Any concrete value is a valid refinement of poison.
607  Value poison = rewriter.create<emitc::ConstantOp>(
608  op.getLoc(), arithmeticType,
609  (isa<IntegerType>(arithmeticType)
610  ? rewriter.getIntegerAttr(arithmeticType, 0)
611  : rewriter.getIndexAttr(0)));
612 
613  emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
614  op.getLoc(), arithmeticType, /*do_not_inline=*/false);
615  Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
616  auto currentPoint = rewriter.getInsertionPoint();
617  rewriter.setInsertionPointToStart(&bodyBlock);
618  Value arithmeticResult =
619  rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
620  Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
621  op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
622  rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
623  rewriter.setInsertionPoint(op->getBlock(), currentPoint);
624 
625  Value result = adaptValueType(ternary, rewriter, type);
626 
627  rewriter.replaceOp(op, result);
628  return success();
629  }
630 };
631 
632 template <typename ArithOp, typename EmitCOp>
633 class SignedShiftOpConversion final
634  : public ShiftOpConversion<ArithOp, EmitCOp, false> {
635  using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
636 };
637 
638 template <typename ArithOp, typename EmitCOp>
639 class UnsignedShiftOpConversion final
640  : public ShiftOpConversion<ArithOp, EmitCOp, true> {
641  using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
642 };
643 
644 class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
645 public:
647 
648  LogicalResult
649  matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
650  ConversionPatternRewriter &rewriter) const override {
651 
652  Type dstType = getTypeConverter()->convertType(selectOp.getType());
653  if (!dstType)
654  return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
655 
656  if (!adaptor.getCondition().getType().isInteger(1))
657  return rewriter.notifyMatchFailure(
658  selectOp,
659  "can only be converted if condition is a scalar of type i1");
660 
661  rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
662  adaptor.getOperands());
663 
664  return success();
665  }
666 };
667 
668 // Floating-point to integer conversions.
669 template <typename CastOp>
670 class FtoICastOpConversion : public OpConversionPattern<CastOp> {
671 public:
672  FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
673  : OpConversionPattern<CastOp>(typeConverter, context) {}
674 
675  LogicalResult
676  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
677  ConversionPatternRewriter &rewriter) const override {
678 
679  Type operandType = adaptor.getIn().getType();
680  if (!emitc::isSupportedFloatType(operandType))
681  return rewriter.notifyMatchFailure(castOp,
682  "unsupported cast source type");
683 
684  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
685  if (!dstType)
686  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
687 
688  // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
689  // truncated to 0, whereas a boolean conversion would return true.
690  if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
691  return rewriter.notifyMatchFailure(castOp,
692  "unsupported cast destination type");
693 
694  // Convert to unsigned if it's the "ui" variant
695  // Signless is interpreted as signed, so no need to cast for "si"
696  Type actualResultType = dstType;
697  if (isa<arith::FPToUIOp>(castOp)) {
698  actualResultType =
699  rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
700  /*isSigned=*/false);
701  }
702 
703  Value result = rewriter.create<emitc::CastOp>(
704  castOp.getLoc(), actualResultType, adaptor.getOperands());
705 
706  if (isa<arith::FPToUIOp>(castOp)) {
707  result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
708  }
709  rewriter.replaceOp(castOp, result);
710 
711  return success();
712  }
713 };
714 
715 // Integer to floating-point conversions.
716 template <typename CastOp>
717 class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
718 public:
719  ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
720  : OpConversionPattern<CastOp>(typeConverter, context) {}
721 
722  LogicalResult
723  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
724  ConversionPatternRewriter &rewriter) const override {
725  // Vectors in particular are not supported
726  Type operandType = adaptor.getIn().getType();
727  if (!emitc::isSupportedIntegerType(operandType))
728  return rewriter.notifyMatchFailure(castOp,
729  "unsupported cast source type");
730 
731  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
732  if (!dstType)
733  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
734 
735  if (!emitc::isSupportedFloatType(dstType))
736  return rewriter.notifyMatchFailure(castOp,
737  "unsupported cast destination type");
738 
739  // Convert to unsigned if it's the "ui" variant
740  // Signless is interpreted as signed, so no need to cast for "si"
741  Type actualOperandType = operandType;
742  if (isa<arith::UIToFPOp>(castOp)) {
743  actualOperandType =
744  rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
745  /*isSigned=*/false);
746  }
747  Value fpCastOperand = adaptor.getIn();
748  if (actualOperandType != operandType) {
749  fpCastOperand = rewriter.template create<emitc::CastOp>(
750  castOp.getLoc(), actualOperandType, fpCastOperand);
751  }
752  rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
753 
754  return success();
755  }
756 };
757 
758 // Floating-point to floating-point conversions.
759 template <typename CastOp>
760 class FpCastOpConversion : public OpConversionPattern<CastOp> {
761 public:
762  FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
763  : OpConversionPattern<CastOp>(typeConverter, context) {}
764 
765  LogicalResult
766  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
767  ConversionPatternRewriter &rewriter) const override {
768  // Vectors in particular are not supported.
769  Type operandType = adaptor.getIn().getType();
770  if (!emitc::isSupportedFloatType(operandType))
771  return rewriter.notifyMatchFailure(castOp,
772  "unsupported cast source type");
773  if (auto roundingModeOp =
774  dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
775  // Only supporting default rounding mode as of now.
776  if (roundingModeOp.getRoundingModeAttr())
777  return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
778  }
779 
780  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
781  if (!dstType)
782  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
783 
784  if (!emitc::isSupportedFloatType(dstType))
785  return rewriter.notifyMatchFailure(castOp,
786  "unsupported cast destination type");
787 
788  Value fpCastOperand = adaptor.getIn();
789  rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
790 
791  return success();
792  }
793 };
794 
795 } // namespace
796 
797 //===----------------------------------------------------------------------===//
798 // Pattern population
799 //===----------------------------------------------------------------------===//
800 
803  MLIRContext *ctx = patterns.getContext();
804 
806 
807  // clang-format off
808  patterns.add<
809  ArithConstantOpConversionPattern,
810  ArithOpConversion<arith::AddFOp, emitc::AddOp>,
811  ArithOpConversion<arith::DivFOp, emitc::DivOp>,
812  ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
813  ArithOpConversion<arith::MulFOp, emitc::MulOp>,
814  ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
815  ArithOpConversion<arith::SubFOp, emitc::SubOp>,
816  BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
817  BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
818  IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
819  IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
820  IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
821  BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
822  BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
823  BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
824  UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
825  SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
826  UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
827  CmpFOpConversion,
828  CmpIOpConversion,
829  NegFOpConversion,
830  SelectOpConversion,
831  // Truncation is guaranteed for unsigned types.
832  UnsignedCastConversion<arith::TruncIOp>,
833  SignedCastConversion<arith::ExtSIOp>,
834  UnsignedCastConversion<arith::ExtUIOp>,
835  SignedCastConversion<arith::IndexCastOp>,
836  UnsignedCastConversion<arith::IndexCastUIOp>,
837  ItoFCastOpConversion<arith::SIToFPOp>,
838  ItoFCastOpConversion<arith::UIToFPOp>,
839  FtoICastOpConversion<arith::FPToSIOp>,
840  FtoICastOpConversion<arith::FPToUIOp>,
841  FpCastOpConversion<arith::ExtFOp>,
842  FpCastOpConversion<arith::TruncFOp>
843  >(typeConverter, ctx);
844  // clang-format on
845 }
Block represents an ordered list of Operations.
Definition: Block.h:33
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:104
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:224
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:96
IntegerType getI1Type()
Definition: Builders.cpp:53
IndexType getIndexType()
Definition: Builders.cpp:51
TypedAttr getOneAttr(Type type)
Definition: Builders.cpp:338
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class describes a specific conversion target.
ConvertToEmitCPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:443
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:396
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:518
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:682
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:500
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:76
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:64
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition: EmitC.cpp:117
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t, emitc.ssize_t or emitc.ptrdiff_t type.
Definition: EmitC.cpp:135
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition: EmitC.cpp:96
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
void registerConvertArithToEmitCInterface(DialectRegistry &registry)
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)