MLIR  20.0.0git
ArithToEmitC.cpp
Go to the documentation of this file.
1 //===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements patterns to convert the Arith dialect to the EmitC
10 // dialect.
11 //
12 //===----------------------------------------------------------------------===//
13 
15 
20 #include "mlir/IR/BuiltinTypes.h"
22 
23 using namespace mlir;
24 
25 //===----------------------------------------------------------------------===//
26 // Conversion Patterns
27 //===----------------------------------------------------------------------===//
28 
29 namespace {
30 class ArithConstantOpConversionPattern
31  : public OpConversionPattern<arith::ConstantOp> {
32 public:
34 
35  LogicalResult
36  matchAndRewrite(arith::ConstantOp arithConst,
37  arith::ConstantOp::Adaptor adaptor,
38  ConversionPatternRewriter &rewriter) const override {
39  Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
40  if (!newTy)
41  return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
42  rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
43  adaptor.getValue());
44  return success();
45  }
46 };
47 
48 /// Get the signed or unsigned type corresponding to \p ty.
49 Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
50  if (isa<IntegerType>(ty)) {
51  if (ty.isUnsignedInteger() != needsUnsigned) {
52  auto signedness = needsUnsigned
53  ? IntegerType::SignednessSemantics::Unsigned
56  signedness);
57  }
58  } else if (emitc::isPointerWideType(ty)) {
59  if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
60  if (needsUnsigned)
61  return emitc::SizeTType::get(ty.getContext());
63  }
64  }
65  return ty;
66 }
67 
68 /// Insert a cast operation to type \p ty if \p val does not have this type.
69 Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
70  return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
71 }
72 
73 class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
74 public:
76 
77  LogicalResult
78  matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
79  ConversionPatternRewriter &rewriter) const override {
80 
81  if (!isa<FloatType>(adaptor.getRhs().getType())) {
82  return rewriter.notifyMatchFailure(op.getLoc(),
83  "cmpf currently only supported on "
84  "floats, not tensors/vectors thereof");
85  }
86 
87  bool unordered = false;
88  emitc::CmpPredicate predicate;
89  switch (op.getPredicate()) {
90  case arith::CmpFPredicate::AlwaysFalse: {
91  auto constant = rewriter.create<emitc::ConstantOp>(
92  op.getLoc(), rewriter.getI1Type(),
93  rewriter.getBoolAttr(/*value=*/false));
94  rewriter.replaceOp(op, constant);
95  return success();
96  }
97  case arith::CmpFPredicate::OEQ:
98  unordered = false;
99  predicate = emitc::CmpPredicate::eq;
100  break;
101  case arith::CmpFPredicate::OGT:
102  unordered = false;
103  predicate = emitc::CmpPredicate::gt;
104  break;
105  case arith::CmpFPredicate::OGE:
106  unordered = false;
107  predicate = emitc::CmpPredicate::ge;
108  break;
109  case arith::CmpFPredicate::OLT:
110  unordered = false;
111  predicate = emitc::CmpPredicate::lt;
112  break;
113  case arith::CmpFPredicate::OLE:
114  unordered = false;
115  predicate = emitc::CmpPredicate::le;
116  break;
117  case arith::CmpFPredicate::ONE:
118  unordered = false;
119  predicate = emitc::CmpPredicate::ne;
120  break;
121  case arith::CmpFPredicate::ORD: {
122  // ordered, i.e. none of the operands is NaN
123  auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
124  adaptor.getRhs());
125  rewriter.replaceOp(op, cmp);
126  return success();
127  }
128  case arith::CmpFPredicate::UEQ:
129  unordered = true;
130  predicate = emitc::CmpPredicate::eq;
131  break;
132  case arith::CmpFPredicate::UGT:
133  unordered = true;
134  predicate = emitc::CmpPredicate::gt;
135  break;
136  case arith::CmpFPredicate::UGE:
137  unordered = true;
138  predicate = emitc::CmpPredicate::ge;
139  break;
140  case arith::CmpFPredicate::ULT:
141  unordered = true;
142  predicate = emitc::CmpPredicate::lt;
143  break;
144  case arith::CmpFPredicate::ULE:
145  unordered = true;
146  predicate = emitc::CmpPredicate::le;
147  break;
148  case arith::CmpFPredicate::UNE:
149  unordered = true;
150  predicate = emitc::CmpPredicate::ne;
151  break;
152  case arith::CmpFPredicate::UNO: {
153  // unordered, i.e. either operand is nan
154  auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
155  adaptor.getRhs());
156  rewriter.replaceOp(op, cmp);
157  return success();
158  }
159  case arith::CmpFPredicate::AlwaysTrue: {
160  auto constant = rewriter.create<emitc::ConstantOp>(
161  op.getLoc(), rewriter.getI1Type(),
162  rewriter.getBoolAttr(/*value=*/true));
163  rewriter.replaceOp(op, constant);
164  return success();
165  }
166  }
167 
168  // Compare the values naively
169  auto cmpResult =
170  rewriter.create<emitc::CmpOp>(op.getLoc(), op.getType(), predicate,
171  adaptor.getLhs(), adaptor.getRhs());
172 
173  // Adjust the results for unordered/ordered semantics
174  if (unordered) {
175  auto isUnordered = createCheckIsUnordered(
176  rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
177  rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
178  isUnordered, cmpResult);
179  return success();
180  }
181 
182  auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
183  adaptor.getLhs(), adaptor.getRhs());
184  rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
185  isOrdered, cmpResult);
186  return success();
187  }
188 
189 private:
190  /// Return a value that is true if \p operand is NaN.
191  Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
192  Value operand) const {
193  // A value is NaN exactly when it compares unequal to itself.
194  return rewriter.create<emitc::CmpOp>(
195  loc, rewriter.getI1Type(), emitc::CmpPredicate::ne, operand, operand);
196  }
197 
198  /// Return a value that is true if \p operand is not NaN.
199  Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
200  Value operand) const {
201  // A value is not NaN exactly when it compares equal to itself.
202  return rewriter.create<emitc::CmpOp>(
203  loc, rewriter.getI1Type(), emitc::CmpPredicate::eq, operand, operand);
204  }
205 
206  /// Return a value that is true if the operands \p first and \p second are
207  /// unordered (i.e., at least one of them is NaN).
208  Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
209  Location loc, Value first, Value second) const {
210  auto firstIsNaN = isNaN(rewriter, loc, first);
211  auto secondIsNaN = isNaN(rewriter, loc, second);
212  return rewriter.create<emitc::LogicalOrOp>(loc, rewriter.getI1Type(),
213  firstIsNaN, secondIsNaN);
214  }
215 
216  /// Return a value that is true if the operands \p first and \p second are
217  /// both ordered (i.e., none one of them is NaN).
218  Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
219  Value first, Value second) const {
220  auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
221  auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
222  return rewriter.create<emitc::LogicalAndOp>(loc, rewriter.getI1Type(),
223  firstIsNotNaN, secondIsNotNaN);
224  }
225 };
226 
227 class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
228 public:
230 
231  bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
232  switch (pred) {
233  case arith::CmpIPredicate::eq:
234  case arith::CmpIPredicate::ne:
235  case arith::CmpIPredicate::slt:
236  case arith::CmpIPredicate::sle:
237  case arith::CmpIPredicate::sgt:
238  case arith::CmpIPredicate::sge:
239  return false;
240  case arith::CmpIPredicate::ult:
241  case arith::CmpIPredicate::ule:
242  case arith::CmpIPredicate::ugt:
243  case arith::CmpIPredicate::uge:
244  return true;
245  }
246  llvm_unreachable("unknown cmpi predicate kind");
247  }
248 
249  emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
250  switch (pred) {
251  case arith::CmpIPredicate::eq:
252  return emitc::CmpPredicate::eq;
253  case arith::CmpIPredicate::ne:
254  return emitc::CmpPredicate::ne;
255  case arith::CmpIPredicate::slt:
256  case arith::CmpIPredicate::ult:
257  return emitc::CmpPredicate::lt;
258  case arith::CmpIPredicate::sle:
259  case arith::CmpIPredicate::ule:
260  return emitc::CmpPredicate::le;
261  case arith::CmpIPredicate::sgt:
262  case arith::CmpIPredicate::ugt:
263  return emitc::CmpPredicate::gt;
264  case arith::CmpIPredicate::sge:
265  case arith::CmpIPredicate::uge:
266  return emitc::CmpPredicate::ge;
267  }
268  llvm_unreachable("unknown cmpi predicate kind");
269  }
270 
271  LogicalResult
272  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
273  ConversionPatternRewriter &rewriter) const override {
274 
275  Type type = adaptor.getLhs().getType();
276  if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
277  return rewriter.notifyMatchFailure(
278  op, "expected integer or size_t/ssize_t/ptrdiff_t type");
279  }
280 
281  bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
282  emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
283 
284  Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
285  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
286  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
287 
288  rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
289  return success();
290  }
291 };
292 
293 class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
294 public:
296 
297  LogicalResult
298  matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
299  ConversionPatternRewriter &rewriter) const override {
300 
301  auto adaptedOp = adaptor.getOperand();
302  auto adaptedOpType = adaptedOp.getType();
303 
304  if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
305  return rewriter.notifyMatchFailure(
306  op.getLoc(),
307  "negf currently only supports scalar types, not vectors or tensors");
308  }
309 
310  if (!emitc::isSupportedFloatType(adaptedOpType)) {
311  return rewriter.notifyMatchFailure(
312  op.getLoc(), "floating-point type is not supported by EmitC");
313  }
314 
315  rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType,
316  adaptedOp);
317  return success();
318  }
319 };
320 
321 template <typename ArithOp, bool castToUnsigned>
322 class CastConversion : public OpConversionPattern<ArithOp> {
323 public:
325 
326  LogicalResult
327  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
328  ConversionPatternRewriter &rewriter) const override {
329 
330  Type opReturnType = this->getTypeConverter()->convertType(op.getType());
331  if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
332  emitc::isPointerWideType(opReturnType)))
333  return rewriter.notifyMatchFailure(
334  op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
335 
336  if (adaptor.getOperands().size() != 1) {
337  return rewriter.notifyMatchFailure(
338  op, "CastConversion only supports unary ops");
339  }
340 
341  Type operandType = adaptor.getIn().getType();
342  if (!operandType || !(isa<IntegerType>(operandType) ||
343  emitc::isPointerWideType(operandType)))
344  return rewriter.notifyMatchFailure(
345  op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
346 
347  // Signed (sign-extending) casts from i1 are not supported.
348  if (operandType.isInteger(1) && !castToUnsigned)
349  return rewriter.notifyMatchFailure(op,
350  "operation not supported on i1 type");
351 
352  // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
353  // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
354  // truncation.
355  if (opReturnType.isInteger(1)) {
356  Type attrType = (emitc::isPointerWideType(operandType))
357  ? rewriter.getIndexType()
358  : operandType;
359  auto constOne = rewriter.create<emitc::ConstantOp>(
360  op.getLoc(), operandType, rewriter.getOneAttr(attrType));
361  auto oneAndOperand = rewriter.create<emitc::BitwiseAndOp>(
362  op.getLoc(), operandType, adaptor.getIn(), constOne);
363  rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
364  oneAndOperand);
365  return success();
366  }
367 
368  bool isTruncation =
369  (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
370  operandType.getIntOrFloatBitWidth() >
371  opReturnType.getIntOrFloatBitWidth());
372  bool doUnsigned = castToUnsigned || isTruncation;
373 
374  // Adapt the signedness of the result (bitwidth-preserving cast)
375  // This is needed e.g., if the return type is signless.
376  Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
377 
378  // Adapt the signedness of the operand (bitwidth-preserving cast)
379  Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
380  Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
381 
382  // Actual cast (may change bitwidth)
383  auto cast = rewriter.template create<emitc::CastOp>(op.getLoc(),
384  castDestType, actualOp);
385 
386  // Cast to the expected output type
387  auto result = adaptValueType(cast, rewriter, opReturnType);
388 
389  rewriter.replaceOp(op, result);
390  return success();
391  }
392 };
393 
394 template <typename ArithOp>
395 class UnsignedCastConversion : public CastConversion<ArithOp, true> {
396  using CastConversion<ArithOp, true>::CastConversion;
397 };
398 
399 template <typename ArithOp>
400 class SignedCastConversion : public CastConversion<ArithOp, false> {
401  using CastConversion<ArithOp, false>::CastConversion;
402 };
403 
404 template <typename ArithOp, typename EmitCOp>
405 class ArithOpConversion final : public OpConversionPattern<ArithOp> {
406 public:
408 
409  LogicalResult
410  matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
411  ConversionPatternRewriter &rewriter) const override {
412 
413  Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
414  if (!newTy)
415  return rewriter.notifyMatchFailure(arithOp,
416  "converting result type failed");
417  rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
418  adaptor.getOperands());
419 
420  return success();
421  }
422 };
423 
424 template <class ArithOp, class EmitCOp>
425 class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
426 public:
428 
429  LogicalResult
430  matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor,
431  ConversionPatternRewriter &rewriter) const override {
432  Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
433  if (!newRetTy)
434  return rewriter.notifyMatchFailure(uiBinOp,
435  "converting result type failed");
436  if (!isa<IntegerType>(newRetTy)) {
437  return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
438  }
439  Type unsignedType =
440  adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true);
441  if (!unsignedType)
442  return rewriter.notifyMatchFailure(uiBinOp,
443  "converting result type failed");
444  Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
445  Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
446 
447  auto newDivOp =
448  rewriter.create<EmitCOp>(uiBinOp.getLoc(), unsignedType,
449  ArrayRef<Value>{lhsAdapted, rhsAdapted});
450  Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
451  rewriter.replaceOp(uiBinOp, resultAdapted);
452  return success();
453  }
454 };
455 
456 template <typename ArithOp, typename EmitCOp>
457 class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
458 public:
460 
461  LogicalResult
462  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
463  ConversionPatternRewriter &rewriter) const override {
464 
465  Type type = this->getTypeConverter()->convertType(op.getType());
466  if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
467  return rewriter.notifyMatchFailure(
468  op, "expected integer or size_t/ssize_t/ptrdiff_t type");
469  }
470 
471  if (type.isInteger(1)) {
472  // arith expects wrap-around arithmethic, which doesn't happen on `bool`.
473  return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
474  }
475 
476  Type arithmeticType = type;
477  if ((type.isSignlessInteger() || type.isSignedInteger()) &&
478  !bitEnumContainsAll(op.getOverflowFlags(),
479  arith::IntegerOverflowFlags::nsw)) {
480  // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
481  // we compute in unsigned integers to avoid UB.
482  arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
483  /*isSigned=*/false);
484  }
485 
486  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
487  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
488 
489  Value arithmeticResult = rewriter.template create<EmitCOp>(
490  op.getLoc(), arithmeticType, lhs, rhs);
491 
492  Value result = adaptValueType(arithmeticResult, rewriter, type);
493 
494  rewriter.replaceOp(op, result);
495  return success();
496  }
497 };
498 
499 template <typename ArithOp, typename EmitCOp>
500 class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
501 public:
503 
504  LogicalResult
505  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
506  ConversionPatternRewriter &rewriter) const override {
507 
508  Type type = this->getTypeConverter()->convertType(op.getType());
509  if (!isa_and_nonnull<IntegerType>(type)) {
510  return rewriter.notifyMatchFailure(
511  op,
512  "expected integer type, vector/tensor support not yet implemented");
513  }
514 
515  // Bitwise ops can be performed directly on booleans
516  if (type.isInteger(1)) {
517  rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
518  adaptor.getRhs());
519  return success();
520  }
521 
522  // Bitwise ops are defined by the C standard on unsigned operands.
523  Type arithmeticType =
524  adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);
525 
526  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
527  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
528 
529  Value arithmeticResult = rewriter.template create<EmitCOp>(
530  op.getLoc(), arithmeticType, lhs, rhs);
531 
532  Value result = adaptValueType(arithmeticResult, rewriter, type);
533 
534  rewriter.replaceOp(op, result);
535  return success();
536  }
537 };
538 
539 template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
540 class ShiftOpConversion : public OpConversionPattern<ArithOp> {
541 public:
543 
544  LogicalResult
545  matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
546  ConversionPatternRewriter &rewriter) const override {
547 
548  Type type = this->getTypeConverter()->convertType(op.getType());
549  if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
550  return rewriter.notifyMatchFailure(
551  op, "expected integer or size_t/ssize_t/ptrdiff_t type");
552  }
553 
554  if (type.isInteger(1)) {
555  return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
556  }
557 
558  Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
559 
560  Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
561  // Shift amount interpreted as unsigned per Arith dialect spec.
562  Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
563  /*needsUnsigned=*/true);
564  Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
565 
566  // Add a runtime check for overflow
567  Value width;
568  if (emitc::isPointerWideType(type)) {
569  Value eight = rewriter.create<emitc::ConstantOp>(
570  op.getLoc(), rhsType, rewriter.getIndexAttr(8));
571  emitc::CallOpaqueOp sizeOfCall = rewriter.create<emitc::CallOpaqueOp>(
572  op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
573  width = rewriter.create<emitc::MulOp>(op.getLoc(), rhsType, eight,
574  sizeOfCall.getResult(0));
575  } else {
576  width = rewriter.create<emitc::ConstantOp>(
577  op.getLoc(), rhsType,
578  rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
579  }
580 
581  Value excessCheck = rewriter.create<emitc::CmpOp>(
582  op.getLoc(), rewriter.getI1Type(), emitc::CmpPredicate::lt, rhs, width);
583 
584  // Any concrete value is a valid refinement of poison.
585  Value poison = rewriter.create<emitc::ConstantOp>(
586  op.getLoc(), arithmeticType,
587  (isa<IntegerType>(arithmeticType)
588  ? rewriter.getIntegerAttr(arithmeticType, 0)
589  : rewriter.getIndexAttr(0)));
590 
591  emitc::ExpressionOp ternary = rewriter.create<emitc::ExpressionOp>(
592  op.getLoc(), arithmeticType, /*do_not_inline=*/false);
593  Block &bodyBlock = ternary.getBodyRegion().emplaceBlock();
594  auto currentPoint = rewriter.getInsertionPoint();
595  rewriter.setInsertionPointToStart(&bodyBlock);
596  Value arithmeticResult =
597  rewriter.create<EmitCOp>(op.getLoc(), arithmeticType, lhs, rhs);
598  Value resultOrPoison = rewriter.create<emitc::ConditionalOp>(
599  op.getLoc(), arithmeticType, excessCheck, arithmeticResult, poison);
600  rewriter.create<emitc::YieldOp>(op.getLoc(), resultOrPoison);
601  rewriter.setInsertionPoint(op->getBlock(), currentPoint);
602 
603  Value result = adaptValueType(ternary, rewriter, type);
604 
605  rewriter.replaceOp(op, result);
606  return success();
607  }
608 };
609 
610 template <typename ArithOp, typename EmitCOp>
611 class SignedShiftOpConversion final
612  : public ShiftOpConversion<ArithOp, EmitCOp, false> {
613  using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
614 };
615 
616 template <typename ArithOp, typename EmitCOp>
617 class UnsignedShiftOpConversion final
618  : public ShiftOpConversion<ArithOp, EmitCOp, true> {
619  using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
620 };
621 
622 class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
623 public:
625 
626  LogicalResult
627  matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
628  ConversionPatternRewriter &rewriter) const override {
629 
630  Type dstType = getTypeConverter()->convertType(selectOp.getType());
631  if (!dstType)
632  return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
633 
634  if (!adaptor.getCondition().getType().isInteger(1))
635  return rewriter.notifyMatchFailure(
636  selectOp,
637  "can only be converted if condition is a scalar of type i1");
638 
639  rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
640  adaptor.getOperands());
641 
642  return success();
643  }
644 };
645 
646 // Floating-point to integer conversions.
647 template <typename CastOp>
648 class FtoICastOpConversion : public OpConversionPattern<CastOp> {
649 public:
650  FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
651  : OpConversionPattern<CastOp>(typeConverter, context) {}
652 
653  LogicalResult
654  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
655  ConversionPatternRewriter &rewriter) const override {
656 
657  Type operandType = adaptor.getIn().getType();
658  if (!emitc::isSupportedFloatType(operandType))
659  return rewriter.notifyMatchFailure(castOp,
660  "unsupported cast source type");
661 
662  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
663  if (!dstType)
664  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
665 
666  // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
667  // truncated to 0, whereas a boolean conversion would return true.
668  if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
669  return rewriter.notifyMatchFailure(castOp,
670  "unsupported cast destination type");
671 
672  // Convert to unsigned if it's the "ui" variant
673  // Signless is interpreted as signed, so no need to cast for "si"
674  Type actualResultType = dstType;
675  if (isa<arith::FPToUIOp>(castOp)) {
676  actualResultType =
677  rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
678  /*isSigned=*/false);
679  }
680 
681  Value result = rewriter.create<emitc::CastOp>(
682  castOp.getLoc(), actualResultType, adaptor.getOperands());
683 
684  if (isa<arith::FPToUIOp>(castOp)) {
685  result = rewriter.create<emitc::CastOp>(castOp.getLoc(), dstType, result);
686  }
687  rewriter.replaceOp(castOp, result);
688 
689  return success();
690  }
691 };
692 
693 // Integer to floating-point conversions.
694 template <typename CastOp>
695 class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
696 public:
697  ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
698  : OpConversionPattern<CastOp>(typeConverter, context) {}
699 
700  LogicalResult
701  matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
702  ConversionPatternRewriter &rewriter) const override {
703  // Vectors in particular are not supported
704  Type operandType = adaptor.getIn().getType();
705  if (!emitc::isSupportedIntegerType(operandType))
706  return rewriter.notifyMatchFailure(castOp,
707  "unsupported cast source type");
708 
709  Type dstType = this->getTypeConverter()->convertType(castOp.getType());
710  if (!dstType)
711  return rewriter.notifyMatchFailure(castOp, "type conversion failed");
712 
713  if (!emitc::isSupportedFloatType(dstType))
714  return rewriter.notifyMatchFailure(castOp,
715  "unsupported cast destination type");
716 
717  // Convert to unsigned if it's the "ui" variant
718  // Signless is interpreted as signed, so no need to cast for "si"
719  Type actualOperandType = operandType;
720  if (isa<arith::UIToFPOp>(castOp)) {
721  actualOperandType =
722  rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
723  /*isSigned=*/false);
724  }
725  Value fpCastOperand = adaptor.getIn();
726  if (actualOperandType != operandType) {
727  fpCastOperand = rewriter.template create<emitc::CastOp>(
728  castOp.getLoc(), actualOperandType, fpCastOperand);
729  }
730  rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
731 
732  return success();
733  }
734 };
735 
736 } // namespace
737 
738 //===----------------------------------------------------------------------===//
739 // Pattern population
740 //===----------------------------------------------------------------------===//
741 
743  RewritePatternSet &patterns) {
744  MLIRContext *ctx = patterns.getContext();
745 
747 
748  // clang-format off
749  patterns.add<
750  ArithConstantOpConversionPattern,
751  ArithOpConversion<arith::AddFOp, emitc::AddOp>,
752  ArithOpConversion<arith::DivFOp, emitc::DivOp>,
753  ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
754  ArithOpConversion<arith::MulFOp, emitc::MulOp>,
755  ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
756  ArithOpConversion<arith::SubFOp, emitc::SubOp>,
757  BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
758  BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
759  IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
760  IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
761  IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
762  BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
763  BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
764  BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
765  UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
766  SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
767  UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
768  CmpFOpConversion,
769  CmpIOpConversion,
770  NegFOpConversion,
771  SelectOpConversion,
772  // Truncation is guaranteed for unsigned types.
773  UnsignedCastConversion<arith::TruncIOp>,
774  SignedCastConversion<arith::ExtSIOp>,
775  UnsignedCastConversion<arith::ExtUIOp>,
776  SignedCastConversion<arith::IndexCastOp>,
777  UnsignedCastConversion<arith::IndexCastUIOp>,
778  ItoFCastOpConversion<arith::SIToFPOp>,
779  ItoFCastOpConversion<arith::UIToFPOp>,
780  FtoICastOpConversion<arith::FPToSIOp>,
781  FtoICastOpConversion<arith::FPToUIOp>
782  >(typeConverter, ctx);
783  // clang-format on
784 }
Block represents an ordered list of Operations.
Definition: Block.h:33
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:148
IntegerAttr getIntegerAttr(Type type, int64_t value)
Definition: Builders.cpp:268
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
BoolAttr getBoolAttr(bool value)
Definition: Builders.cpp:140
IntegerType getI1Type()
Definition: Builders.cpp:97
IndexType getIndexType()
Definition: Builders.cpp:95
TypedAttr getOneAttr(Type type)
Definition: Builders.cpp:382
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition: Builders.h:453
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:406
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition: Builders.h:528
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
MLIRContext * getContext() const
Definition: PatternMatch.h:829
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Type conversion class.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition: Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition: Types.cpp:87
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition: Types.cpp:75
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition: Types.cpp:99
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition: Types.cpp:66
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Location getLoc() const
Return the location of this value.
Definition: Value.cpp:26
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition: EmitC.cpp:116
bool isPointerWideType(mlir::Type type)
Determines whether type is a emitc.size_t/ssize_t type.
Definition: EmitC.cpp:134
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition: EmitC.cpp:95
CmpPredicate
Copy of the enum from arith and index to allow the common integer range infrastructure to not depend ...
Include the generated interface declarations.
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)