MLIR 22.0.0git
ArithToEmitC.cpp
Go to the documentation of this file.
1//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert the Arith dialect to the EmitC
10// dialect.
11//
12//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArithToEmitC/ArithToEmitC.h"

#include "mlir/Conversion/ConvertToEmitC/ToEmitCInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/EmitC/IR/EmitC.h"
#include "mlir/Dialect/EmitC/Transforms/TypeConversions.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"

24using namespace mlir;
25
26namespace {
27/// Implement the interface to convert Arith to EmitC.
28struct ArithToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
30
31 /// Hook for derived dialect interface to provide conversion patterns
32 /// and mark dialect legal for the conversion target.
33 void populateConvertToEmitCConversionPatterns(
34 ConversionTarget &target, TypeConverter &typeConverter,
35 RewritePatternSet &patterns) const final {
37 }
38};
39} // namespace
40
42 registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
43 dialect->addInterfaces<ArithToEmitCDialectInterface>();
44 });
45}
46
47//===----------------------------------------------------------------------===//
48// Conversion Patterns
49//===----------------------------------------------------------------------===//
50
51namespace {
52class ArithConstantOpConversionPattern
53 : public OpConversionPattern<arith::ConstantOp> {
54public:
55 using Base::Base;
56
57 LogicalResult
58 matchAndRewrite(arith::ConstantOp arithConst,
59 arith::ConstantOp::Adaptor adaptor,
60 ConversionPatternRewriter &rewriter) const override {
61 Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
62 if (!newTy)
63 return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
64 rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
65 adaptor.getValue());
66 return success();
67 }
68};
69
70/// Get the signed or unsigned type corresponding to \p ty.
71Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
72 if (isa<IntegerType>(ty)) {
73 if (ty.isUnsignedInteger() != needsUnsigned) {
74 auto signedness = needsUnsigned
75 ? IntegerType::SignednessSemantics::Unsigned
76 : IntegerType::SignednessSemantics::Signed;
77 return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
78 signedness);
79 }
80 } else if (emitc::isPointerWideType(ty)) {
81 if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
82 if (needsUnsigned)
83 return emitc::SizeTType::get(ty.getContext());
84 return emitc::PtrDiffTType::get(ty.getContext());
85 }
86 }
87 return ty;
88}
89
90/// Insert a cast operation to type \p ty if \p val does not have this type.
91Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
92 return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
93}
94
95class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
96public:
97 using Base::Base;
98
99 LogicalResult
100 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
101 ConversionPatternRewriter &rewriter) const override {
102
103 if (!isa<FloatType>(adaptor.getRhs().getType())) {
104 return rewriter.notifyMatchFailure(op.getLoc(),
105 "cmpf currently only supported on "
106 "floats, not tensors/vectors thereof");
107 }
108
109 bool unordered = false;
110 emitc::CmpPredicate predicate;
111 switch (op.getPredicate()) {
112 case arith::CmpFPredicate::AlwaysFalse: {
113 auto constant =
114 emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
115 rewriter.getBoolAttr(/*value=*/false));
116 rewriter.replaceOp(op, constant);
117 return success();
118 }
119 case arith::CmpFPredicate::OEQ:
120 unordered = false;
121 predicate = emitc::CmpPredicate::eq;
122 break;
123 case arith::CmpFPredicate::OGT:
124 unordered = false;
125 predicate = emitc::CmpPredicate::gt;
126 break;
127 case arith::CmpFPredicate::OGE:
128 unordered = false;
129 predicate = emitc::CmpPredicate::ge;
130 break;
131 case arith::CmpFPredicate::OLT:
132 unordered = false;
133 predicate = emitc::CmpPredicate::lt;
134 break;
135 case arith::CmpFPredicate::OLE:
136 unordered = false;
137 predicate = emitc::CmpPredicate::le;
138 break;
139 case arith::CmpFPredicate::ONE:
140 unordered = false;
141 predicate = emitc::CmpPredicate::ne;
142 break;
143 case arith::CmpFPredicate::ORD: {
144 // ordered, i.e. none of the operands is NaN
145 auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
146 adaptor.getRhs());
147 rewriter.replaceOp(op, cmp);
148 return success();
149 }
150 case arith::CmpFPredicate::UEQ:
151 unordered = true;
152 predicate = emitc::CmpPredicate::eq;
153 break;
154 case arith::CmpFPredicate::UGT:
155 unordered = true;
156 predicate = emitc::CmpPredicate::gt;
157 break;
158 case arith::CmpFPredicate::UGE:
159 unordered = true;
160 predicate = emitc::CmpPredicate::ge;
161 break;
162 case arith::CmpFPredicate::ULT:
163 unordered = true;
164 predicate = emitc::CmpPredicate::lt;
165 break;
166 case arith::CmpFPredicate::ULE:
167 unordered = true;
168 predicate = emitc::CmpPredicate::le;
169 break;
170 case arith::CmpFPredicate::UNE:
171 unordered = true;
172 predicate = emitc::CmpPredicate::ne;
173 break;
174 case arith::CmpFPredicate::UNO: {
175 // unordered, i.e. either operand is nan
176 auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
177 adaptor.getRhs());
178 rewriter.replaceOp(op, cmp);
179 return success();
180 }
181 case arith::CmpFPredicate::AlwaysTrue: {
182 auto constant =
183 emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
184 rewriter.getBoolAttr(/*value=*/true));
185 rewriter.replaceOp(op, constant);
186 return success();
187 }
188 }
189
190 // Compare the values naively
191 auto cmpResult =
192 emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate,
193 adaptor.getLhs(), adaptor.getRhs());
194
195 // Adjust the results for unordered/ordered semantics
196 if (unordered) {
197 auto isUnordered = createCheckIsUnordered(
198 rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
199 rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
200 isUnordered, cmpResult);
201 return success();
202 }
203
204 auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
205 adaptor.getLhs(), adaptor.getRhs());
206 rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
207 isOrdered, cmpResult);
208 return success();
209 }
210
211private:
212 /// Return a value that is true if \p operand is NaN.
213 Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
214 Value operand) const {
215 // A value is NaN exactly when it compares unequal to itself.
216 return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
217 emitc::CmpPredicate::ne, operand, operand);
218 }
219
220 /// Return a value that is true if \p operand is not NaN.
221 Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
222 Value operand) const {
223 // A value is not NaN exactly when it compares equal to itself.
224 return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
225 emitc::CmpPredicate::eq, operand, operand);
226 }
227
228 /// Return a value that is true if the operands \p first and \p second are
229 /// unordered (i.e., at least one of them is NaN).
230 Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
231 Location loc, Value first, Value second) const {
232 auto firstIsNaN = isNaN(rewriter, loc, first);
233 auto secondIsNaN = isNaN(rewriter, loc, second);
234 return emitc::LogicalOrOp::create(rewriter, loc, rewriter.getI1Type(),
235 firstIsNaN, secondIsNaN);
236 }
237
238 /// Return a value that is true if the operands \p first and \p second are
239 /// both ordered (i.e., none one of them is NaN).
240 Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
241 Value first, Value second) const {
242 auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
243 auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
244 return emitc::LogicalAndOp::create(rewriter, loc, rewriter.getI1Type(),
245 firstIsNotNaN, secondIsNotNaN);
246 }
247};
248
249class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
250public:
251 using Base::Base;
252
253 bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
254 switch (pred) {
255 case arith::CmpIPredicate::eq:
256 case arith::CmpIPredicate::ne:
257 case arith::CmpIPredicate::slt:
258 case arith::CmpIPredicate::sle:
259 case arith::CmpIPredicate::sgt:
260 case arith::CmpIPredicate::sge:
261 return false;
262 case arith::CmpIPredicate::ult:
263 case arith::CmpIPredicate::ule:
264 case arith::CmpIPredicate::ugt:
265 case arith::CmpIPredicate::uge:
266 return true;
267 }
268 llvm_unreachable("unknown cmpi predicate kind");
269 }
270
271 emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
272 switch (pred) {
273 case arith::CmpIPredicate::eq:
274 return emitc::CmpPredicate::eq;
275 case arith::CmpIPredicate::ne:
276 return emitc::CmpPredicate::ne;
277 case arith::CmpIPredicate::slt:
278 case arith::CmpIPredicate::ult:
279 return emitc::CmpPredicate::lt;
280 case arith::CmpIPredicate::sle:
281 case arith::CmpIPredicate::ule:
282 return emitc::CmpPredicate::le;
283 case arith::CmpIPredicate::sgt:
284 case arith::CmpIPredicate::ugt:
285 return emitc::CmpPredicate::gt;
286 case arith::CmpIPredicate::sge:
287 case arith::CmpIPredicate::uge:
288 return emitc::CmpPredicate::ge;
289 }
290 llvm_unreachable("unknown cmpi predicate kind");
291 }
292
293 LogicalResult
294 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter) const override {
296
297 Type type = adaptor.getLhs().getType();
298 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
299 return rewriter.notifyMatchFailure(
300 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
301 }
302
303 bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
304 emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
305
306 Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
307 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
308 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
309
310 rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
311 return success();
312 }
313};
314
315class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
316public:
317 using Base::Base;
318
319 LogicalResult
320 matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
321 ConversionPatternRewriter &rewriter) const override {
322
323 auto adaptedOp = adaptor.getOperand();
324 auto adaptedOpType = adaptedOp.getType();
325
326 if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
327 return rewriter.notifyMatchFailure(
328 op.getLoc(),
329 "negf currently only supports scalar types, not vectors or tensors");
330 }
331
332 if (!emitc::isSupportedFloatType(adaptedOpType)) {
333 return rewriter.notifyMatchFailure(
334 op.getLoc(), "floating-point type is not supported by EmitC");
335 }
336
337 rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType,
338 adaptedOp);
339 return success();
340 }
341};
342
343template <typename ArithOp, bool castToUnsigned>
344class CastConversion : public OpConversionPattern<ArithOp> {
345public:
346 using OpConversionPattern<ArithOp>::OpConversionPattern;
347
348 LogicalResult
349 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
350 ConversionPatternRewriter &rewriter) const override {
351
352 Type opReturnType = this->getTypeConverter()->convertType(op.getType());
353 if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
354 emitc::isPointerWideType(opReturnType)))
355 return rewriter.notifyMatchFailure(
356 op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
357
358 if (adaptor.getOperands().size() != 1) {
359 return rewriter.notifyMatchFailure(
360 op, "CastConversion only supports unary ops");
361 }
362
363 Type operandType = adaptor.getIn().getType();
364 if (!operandType || !(isa<IntegerType>(operandType) ||
365 emitc::isPointerWideType(operandType)))
366 return rewriter.notifyMatchFailure(
367 op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
368
369 // Signed (sign-extending) casts from i1 are not supported.
370 if (operandType.isInteger(1) && !castToUnsigned)
371 return rewriter.notifyMatchFailure(op,
372 "operation not supported on i1 type");
373
374 // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
375 // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
376 // truncation.
377 if (opReturnType.isInteger(1)) {
378 Type attrType = (emitc::isPointerWideType(operandType))
379 ? rewriter.getIndexType()
380 : operandType;
381 auto constOne = emitc::ConstantOp::create(
382 rewriter, op.getLoc(), operandType, rewriter.getOneAttr(attrType));
383 auto oneAndOperand = emitc::BitwiseAndOp::create(
384 rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne);
385 rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
386 oneAndOperand);
387 return success();
388 }
389
390 bool isTruncation =
391 (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
392 operandType.getIntOrFloatBitWidth() >
393 opReturnType.getIntOrFloatBitWidth());
394 bool doUnsigned = castToUnsigned || isTruncation;
395
396 // Adapt the signedness of the result (bitwidth-preserving cast)
397 // This is needed e.g., if the return type is signless.
398 Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
399
400 // Adapt the signedness of the operand (bitwidth-preserving cast)
401 Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
402 Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
403
404 // Actual cast (may change bitwidth)
405 auto cast =
406 emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp);
407
408 // Cast to the expected output type
409 auto result = adaptValueType(cast, rewriter, opReturnType);
410
411 rewriter.replaceOp(op, result);
412 return success();
413 }
414};
415
416template <typename ArithOp>
417class UnsignedCastConversion : public CastConversion<ArithOp, true> {
418 using CastConversion<ArithOp, true>::CastConversion;
419};
420
421template <typename ArithOp>
422class SignedCastConversion : public CastConversion<ArithOp, false> {
423 using CastConversion<ArithOp, false>::CastConversion;
424};
425
426template <typename ArithOp, typename EmitCOp>
427class ArithOpConversion final : public OpConversionPattern<ArithOp> {
428public:
429 using OpConversionPattern<ArithOp>::OpConversionPattern;
430
431 LogicalResult
432 matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
433 ConversionPatternRewriter &rewriter) const override {
434
435 Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
436 if (!newTy)
437 return rewriter.notifyMatchFailure(arithOp,
438 "converting result type failed");
439 rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
440 adaptor.getOperands());
441
442 return success();
443 }
444};
445
446template <class ArithOp, class EmitCOp>
447class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
448public:
449 using OpConversionPattern<ArithOp>::OpConversionPattern;
450
451 LogicalResult
452 matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor,
453 ConversionPatternRewriter &rewriter) const override {
454 Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
455 if (!newRetTy)
456 return rewriter.notifyMatchFailure(uiBinOp,
457 "converting result type failed");
458 if (!isa<IntegerType>(newRetTy)) {
459 return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
460 }
461 Type unsignedType =
462 adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true);
463 if (!unsignedType)
464 return rewriter.notifyMatchFailure(uiBinOp,
465 "converting result type failed");
466 Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
467 Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
468
469 auto newDivOp = EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType,
470 ArrayRef<Value>{lhsAdapted, rhsAdapted});
471 Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
472 rewriter.replaceOp(uiBinOp, resultAdapted);
473 return success();
474 }
475};
476
477template <typename ArithOp, typename EmitCOp>
478class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
479public:
480 using OpConversionPattern<ArithOp>::OpConversionPattern;
481
482 LogicalResult
483 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
484 ConversionPatternRewriter &rewriter) const override {
485
486 Type type = this->getTypeConverter()->convertType(op.getType());
487 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
488 return rewriter.notifyMatchFailure(
489 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
490 }
491
492 if (type.isInteger(1)) {
493 // arith expects wrap-around arithmethic, which doesn't happen on `bool`.
494 return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
495 }
496
497 Type arithmeticType = type;
498 if ((type.isSignlessInteger() || type.isSignedInteger()) &&
499 !bitEnumContainsAll(op.getOverflowFlags(),
500 arith::IntegerOverflowFlags::nsw)) {
501 // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
502 // we compute in unsigned integers to avoid UB.
503 arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
504 /*isSigned=*/false);
505 }
506
507 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
508 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
509
510 Value arithmeticResult =
511 EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
512
513 Value result = adaptValueType(arithmeticResult, rewriter, type);
514
515 rewriter.replaceOp(op, result);
516 return success();
517 }
518};
519
520template <typename ArithOp, typename EmitCOp>
521class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
522public:
523 using OpConversionPattern<ArithOp>::OpConversionPattern;
524
525 LogicalResult
526 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
527 ConversionPatternRewriter &rewriter) const override {
528
529 Type type = this->getTypeConverter()->convertType(op.getType());
530 if (!isa_and_nonnull<IntegerType>(type)) {
531 return rewriter.notifyMatchFailure(
532 op,
533 "expected integer type, vector/tensor support not yet implemented");
534 }
535
536 // Bitwise ops can be performed directly on booleans
537 if (type.isInteger(1)) {
538 rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
539 adaptor.getRhs());
540 return success();
541 }
542
543 // Bitwise ops are defined by the C standard on unsigned operands.
544 Type arithmeticType =
545 adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);
546
547 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
548 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
549
550 Value arithmeticResult =
551 EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
552
553 Value result = adaptValueType(arithmeticResult, rewriter, type);
554
555 rewriter.replaceOp(op, result);
556 return success();
557 }
558};
559
560template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
561class ShiftOpConversion : public OpConversionPattern<ArithOp> {
562public:
563 using OpConversionPattern<ArithOp>::OpConversionPattern;
564
565 LogicalResult
566 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
567 ConversionPatternRewriter &rewriter) const override {
568
569 Type type = this->getTypeConverter()->convertType(op.getType());
570 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
571 return rewriter.notifyMatchFailure(
572 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
573 }
574
575 if (type.isInteger(1)) {
576 return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
577 }
578
579 Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
580
581 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
582 // Shift amount interpreted as unsigned per Arith dialect spec.
583 Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
584 /*needsUnsigned=*/true);
585 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
586
587 // Add a runtime check for overflow
588 Value width;
589 if (emitc::isPointerWideType(type)) {
590 Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType,
591 rewriter.getIndexAttr(8));
592 emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create(
593 rewriter, op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
594 width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight,
595 sizeOfCall.getResult(0));
596 } else {
597 width = emitc::ConstantOp::create(
598 rewriter, op.getLoc(), rhsType,
599 rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
600 }
601
602 Value excessCheck =
603 emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
604 emitc::CmpPredicate::lt, rhs, width);
605
606 // Any concrete value is a valid refinement of poison.
607 Value poison = emitc::ConstantOp::create(
608 rewriter, op.getLoc(), arithmeticType,
609 (isa<IntegerType>(arithmeticType)
610 ? rewriter.getIntegerAttr(arithmeticType, 0)
611 : rewriter.getIndexAttr(0)));
612
613 emitc::ExpressionOp ternary =
614 emitc::ExpressionOp::create(rewriter, op.getLoc(), arithmeticType,
615 ValueRange({lhs, rhs, excessCheck, poison}),
616 /*do_not_inline=*/false);
617 Block &bodyBlock = ternary.createBody();
618 auto currentPoint = rewriter.getInsertionPoint();
619 rewriter.setInsertionPointToStart(&bodyBlock);
620 Value arithmeticResult =
621 EmitCOp::create(rewriter, op.getLoc(), arithmeticType,
622 bodyBlock.getArgument(0), bodyBlock.getArgument(1));
623 Value resultOrPoison = emitc::ConditionalOp::create(
624 rewriter, op.getLoc(), arithmeticType, bodyBlock.getArgument(2),
625 arithmeticResult, bodyBlock.getArgument(3));
626 emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
627 rewriter.setInsertionPoint(op->getBlock(), currentPoint);
628
629 Value result = adaptValueType(ternary, rewriter, type);
630
631 rewriter.replaceOp(op, result);
632 return success();
633 }
634};
635
636template <typename ArithOp, typename EmitCOp>
637class SignedShiftOpConversion final
638 : public ShiftOpConversion<ArithOp, EmitCOp, false> {
639 using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
640};
641
642template <typename ArithOp, typename EmitCOp>
643class UnsignedShiftOpConversion final
644 : public ShiftOpConversion<ArithOp, EmitCOp, true> {
645 using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
646};
647
648class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
649public:
650 using Base::Base;
651
652 LogicalResult
653 matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
654 ConversionPatternRewriter &rewriter) const override {
655
656 Type dstType = getTypeConverter()->convertType(selectOp.getType());
657 if (!dstType)
658 return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
659
660 if (!adaptor.getCondition().getType().isInteger(1))
661 return rewriter.notifyMatchFailure(
662 selectOp,
663 "can only be converted if condition is a scalar of type i1");
664
665 rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
666 adaptor.getOperands());
667
668 return success();
669 }
670};
671
672// Floating-point to integer conversions.
673template <typename CastOp>
674class FtoICastOpConversion : public OpConversionPattern<CastOp> {
675public:
676 FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
677 : OpConversionPattern<CastOp>(typeConverter, context) {}
678
679 LogicalResult
680 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
681 ConversionPatternRewriter &rewriter) const override {
682
683 Type operandType = adaptor.getIn().getType();
684 if (!emitc::isSupportedFloatType(operandType))
685 return rewriter.notifyMatchFailure(castOp,
686 "unsupported cast source type");
687
688 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
689 if (!dstType)
690 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
691
692 // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
693 // truncated to 0, whereas a boolean conversion would return true.
694 if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
695 return rewriter.notifyMatchFailure(castOp,
696 "unsupported cast destination type");
697
698 // Convert to unsigned if it's the "ui" variant
699 // Signless is interpreted as signed, so no need to cast for "si"
700 Type actualResultType = dstType;
701 if (isa<arith::FPToUIOp>(castOp)) {
702 actualResultType =
703 rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
704 /*isSigned=*/false);
705 }
706
707 Value result = emitc::CastOp::create(
708 rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands());
709
710 if (isa<arith::FPToUIOp>(castOp)) {
711 result =
712 emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result);
713 }
714 rewriter.replaceOp(castOp, result);
715
716 return success();
717 }
718};
719
720// Integer to floating-point conversions.
721template <typename CastOp>
722class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
723public:
724 ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
725 : OpConversionPattern<CastOp>(typeConverter, context) {}
726
727 LogicalResult
728 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
729 ConversionPatternRewriter &rewriter) const override {
730 // Vectors in particular are not supported
731 Type operandType = adaptor.getIn().getType();
732 if (!emitc::isSupportedIntegerType(operandType))
733 return rewriter.notifyMatchFailure(castOp,
734 "unsupported cast source type");
735
736 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
737 if (!dstType)
738 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
739
740 if (!emitc::isSupportedFloatType(dstType))
741 return rewriter.notifyMatchFailure(castOp,
742 "unsupported cast destination type");
743
744 // Convert to unsigned if it's the "ui" variant
745 // Signless is interpreted as signed, so no need to cast for "si"
746 Type actualOperandType = operandType;
747 if (isa<arith::UIToFPOp>(castOp)) {
748 actualOperandType =
749 rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
750 /*isSigned=*/false);
751 }
752 Value fpCastOperand = adaptor.getIn();
753 if (actualOperandType != operandType) {
754 fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(),
755 actualOperandType, fpCastOperand);
756 }
757 rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
758
759 return success();
760 }
761};
762
763// Floating-point to floating-point conversions.
764template <typename CastOp>
765class FpCastOpConversion : public OpConversionPattern<CastOp> {
766public:
767 FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
768 : OpConversionPattern<CastOp>(typeConverter, context) {}
769
770 LogicalResult
771 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
772 ConversionPatternRewriter &rewriter) const override {
773 // Vectors in particular are not supported.
774 Type operandType = adaptor.getIn().getType();
775 if (!emitc::isSupportedFloatType(operandType))
776 return rewriter.notifyMatchFailure(castOp,
777 "unsupported cast source type");
778 if (auto roundingModeOp =
779 dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
780 // Only supporting default rounding mode as of now.
781 if (roundingModeOp.getRoundingModeAttr())
782 return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
783 }
784
785 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
786 if (!dstType)
787 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
788
789 if (!emitc::isSupportedFloatType(dstType))
790 return rewriter.notifyMatchFailure(castOp,
791 "unsupported cast destination type");
792
793 Value fpCastOperand = adaptor.getIn();
794 rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
795
796 return success();
797 }
798};
799
800} // namespace
801
802//===----------------------------------------------------------------------===//
803// Pattern population
804//===----------------------------------------------------------------------===//
805
808 MLIRContext *ctx = patterns.getContext();
809
811
812 // clang-format off
813 patterns.add<
814 ArithConstantOpConversionPattern,
815 ArithOpConversion<arith::AddFOp, emitc::AddOp>,
816 ArithOpConversion<arith::DivFOp, emitc::DivOp>,
817 ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
818 ArithOpConversion<arith::MulFOp, emitc::MulOp>,
819 ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
820 ArithOpConversion<arith::SubFOp, emitc::SubOp>,
821 BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
822 BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
823 IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
824 IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
825 IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
826 BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
827 BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
828 BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
829 UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
830 SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
831 UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
832 CmpFOpConversion,
833 CmpIOpConversion,
834 NegFOpConversion,
835 SelectOpConversion,
836 // Truncation is guaranteed for unsigned types.
837 UnsignedCastConversion<arith::TruncIOp>,
838 SignedCastConversion<arith::ExtSIOp>,
839 UnsignedCastConversion<arith::ExtUIOp>,
840 SignedCastConversion<arith::IndexCastOp>,
841 UnsignedCastConversion<arith::IndexCastUIOp>,
842 ItoFCastOpConversion<arith::SIToFPOp>,
843 ItoFCastOpConversion<arith::UIToFPOp>,
844 FtoICastOpConversion<arith::FPToSIOp>,
845 FtoICastOpConversion<arith::FPToUIOp>,
846 FpCastOpConversion<arith::ExtFOp>,
847 FpCastOpConversion<arith::TruncFOp>
848 >(typeConverter, ctx);
849 // clang-format on
850}
return success()
lhs
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:76
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:64
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:88
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:56
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:122
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition EmitC.cpp:116
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t, emitc.ssize_t, or emitc.ptrdiff_t type.
Definition EmitC.cpp:131
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition EmitC.cpp:95
Include the generated interface declarations.
void registerConvertArithToEmitCInterface(DialectRegistry &registry)
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
const FrozenRewritePatternSet & patterns
void populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)