MLIR 23.0.0git
ArithToEmitC.cpp
Go to the documentation of this file.
1//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert the Arith dialect to the EmitC
10// dialect.
11//
12//===----------------------------------------------------------------------===//
13
15
23
24using namespace mlir;
25
26namespace {
27/// Implement the interface to convert Arith to EmitC.
28struct ArithToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
30
31 /// Hook for derived dialect interface to provide conversion patterns
32 /// and mark dialect legal for the conversion target.
33 void populateConvertToEmitCConversionPatterns(
34 ConversionTarget &target, TypeConverter &typeConverter,
35 RewritePatternSet &patterns) const final {
36 populateArithToEmitCPatterns(typeConverter, patterns);
37 }
38};
39} // namespace
40
42 registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
43 dialect->addInterfaces<ArithToEmitCDialectInterface>();
44 });
45}
46
47//===----------------------------------------------------------------------===//
48// Conversion Patterns
49//===----------------------------------------------------------------------===//
50
51namespace {
52class ArithConstantOpConversionPattern
53 : public OpConversionPattern<arith::ConstantOp> {
54public:
55 using Base::Base;
56
57 LogicalResult
58 matchAndRewrite(arith::ConstantOp arithConst,
59 arith::ConstantOp::Adaptor adaptor,
60 ConversionPatternRewriter &rewriter) const override {
61 Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
62 if (!newTy)
63 return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
64 rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
65 adaptor.getValue());
66 return success();
67 }
68};
69
70/// Get the signed or unsigned type corresponding to \p ty.
71Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
72 if (isa<IntegerType>(ty)) {
73 if (ty.isUnsignedInteger() != needsUnsigned) {
74 auto signedness = needsUnsigned
75 ? IntegerType::SignednessSemantics::Unsigned
76 : IntegerType::SignednessSemantics::Signed;
77 return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
78 signedness);
79 }
80 } else if (emitc::isPointerWideType(ty)) {
81 if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
82 if (needsUnsigned)
83 return emitc::SizeTType::get(ty.getContext());
84 return emitc::PtrDiffTType::get(ty.getContext());
85 }
86 }
87 return ty;
88}
89
90/// Insert a cast operation to type \p ty if \p val does not have this type.
91Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
92 return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
93}
94
95class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
96public:
97 using Base::Base;
98
99 LogicalResult
100 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
101 ConversionPatternRewriter &rewriter) const override {
102
103 if (!isa<FloatType>(adaptor.getRhs().getType())) {
104 return rewriter.notifyMatchFailure(op.getLoc(),
105 "cmpf currently only supported on "
106 "floats, not tensors/vectors thereof");
107 }
108
109 bool unordered = false;
110 emitc::CmpPredicate predicate;
111 switch (op.getPredicate()) {
112 case arith::CmpFPredicate::AlwaysFalse: {
113 auto constant =
114 emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
115 rewriter.getBoolAttr(/*value=*/false));
116 rewriter.replaceOp(op, constant);
117 return success();
118 }
119 case arith::CmpFPredicate::OEQ:
120 unordered = false;
121 predicate = emitc::CmpPredicate::eq;
122 break;
123 case arith::CmpFPredicate::OGT:
124 unordered = false;
125 predicate = emitc::CmpPredicate::gt;
126 break;
127 case arith::CmpFPredicate::OGE:
128 unordered = false;
129 predicate = emitc::CmpPredicate::ge;
130 break;
131 case arith::CmpFPredicate::OLT:
132 unordered = false;
133 predicate = emitc::CmpPredicate::lt;
134 break;
135 case arith::CmpFPredicate::OLE:
136 unordered = false;
137 predicate = emitc::CmpPredicate::le;
138 break;
139 case arith::CmpFPredicate::ONE:
140 unordered = false;
141 predicate = emitc::CmpPredicate::ne;
142 break;
143 case arith::CmpFPredicate::ORD: {
144 // ordered, i.e. none of the operands is NaN
145 auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
146 adaptor.getRhs());
147 rewriter.replaceOp(op, cmp);
148 return success();
149 }
150 case arith::CmpFPredicate::UEQ:
151 unordered = true;
152 predicate = emitc::CmpPredicate::eq;
153 break;
154 case arith::CmpFPredicate::UGT:
155 unordered = true;
156 predicate = emitc::CmpPredicate::gt;
157 break;
158 case arith::CmpFPredicate::UGE:
159 unordered = true;
160 predicate = emitc::CmpPredicate::ge;
161 break;
162 case arith::CmpFPredicate::ULT:
163 unordered = true;
164 predicate = emitc::CmpPredicate::lt;
165 break;
166 case arith::CmpFPredicate::ULE:
167 unordered = true;
168 predicate = emitc::CmpPredicate::le;
169 break;
170 case arith::CmpFPredicate::UNE:
171 unordered = true;
172 predicate = emitc::CmpPredicate::ne;
173 break;
174 case arith::CmpFPredicate::UNO: {
175 // unordered, i.e. either operand is nan
176 auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
177 adaptor.getRhs());
178 rewriter.replaceOp(op, cmp);
179 return success();
180 }
181 case arith::CmpFPredicate::AlwaysTrue: {
182 auto constant =
183 emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
184 rewriter.getBoolAttr(/*value=*/true));
185 rewriter.replaceOp(op, constant);
186 return success();
187 }
188 }
189
190 // Compare the values naively
191 auto cmpResult =
192 emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate,
193 adaptor.getLhs(), adaptor.getRhs());
194
195 // Adjust the results for unordered/ordered semantics
196 if (unordered) {
197 auto isUnordered = createCheckIsUnordered(
198 rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
199 rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
200 isUnordered, cmpResult);
201 return success();
202 }
203
204 auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
205 adaptor.getLhs(), adaptor.getRhs());
206 rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
207 isOrdered, cmpResult);
208 return success();
209 }
210
211private:
212 /// Return a value that is true if \p operand is NaN.
213 Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
214 Value operand) const {
215 // A value is NaN exactly when it compares unequal to itself.
216 return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
217 emitc::CmpPredicate::ne, operand, operand);
218 }
219
220 /// Return a value that is true if \p operand is not NaN.
221 Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
222 Value operand) const {
223 // A value is not NaN exactly when it compares equal to itself.
224 return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
225 emitc::CmpPredicate::eq, operand, operand);
226 }
227
228 /// Return a value that is true if the operands \p first and \p second are
229 /// unordered (i.e., at least one of them is NaN).
230 Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
231 Location loc, Value first, Value second) const {
232 auto firstIsNaN = isNaN(rewriter, loc, first);
233 auto secondIsNaN = isNaN(rewriter, loc, second);
234 return emitc::LogicalOrOp::create(rewriter, loc, rewriter.getI1Type(),
235 firstIsNaN, secondIsNaN);
236 }
237
238 /// Return a value that is true if the operands \p first and \p second are
239 /// both ordered (i.e., none one of them is NaN).
240 Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
241 Value first, Value second) const {
242 auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
243 auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
244 return emitc::LogicalAndOp::create(rewriter, loc, rewriter.getI1Type(),
245 firstIsNotNaN, secondIsNotNaN);
246 }
247};
248
249class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
250public:
251 using Base::Base;
252
253 bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
254 switch (pred) {
255 case arith::CmpIPredicate::eq:
256 case arith::CmpIPredicate::ne:
257 case arith::CmpIPredicate::slt:
258 case arith::CmpIPredicate::sle:
259 case arith::CmpIPredicate::sgt:
260 case arith::CmpIPredicate::sge:
261 return false;
262 case arith::CmpIPredicate::ult:
263 case arith::CmpIPredicate::ule:
264 case arith::CmpIPredicate::ugt:
265 case arith::CmpIPredicate::uge:
266 return true;
267 }
268 llvm_unreachable("unknown cmpi predicate kind");
269 }
270
271 emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
272 switch (pred) {
273 case arith::CmpIPredicate::eq:
274 return emitc::CmpPredicate::eq;
275 case arith::CmpIPredicate::ne:
276 return emitc::CmpPredicate::ne;
277 case arith::CmpIPredicate::slt:
278 case arith::CmpIPredicate::ult:
279 return emitc::CmpPredicate::lt;
280 case arith::CmpIPredicate::sle:
281 case arith::CmpIPredicate::ule:
282 return emitc::CmpPredicate::le;
283 case arith::CmpIPredicate::sgt:
284 case arith::CmpIPredicate::ugt:
285 return emitc::CmpPredicate::gt;
286 case arith::CmpIPredicate::sge:
287 case arith::CmpIPredicate::uge:
288 return emitc::CmpPredicate::ge;
289 }
290 llvm_unreachable("unknown cmpi predicate kind");
291 }
292
293 LogicalResult
294 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
295 ConversionPatternRewriter &rewriter) const override {
296
297 Type type = adaptor.getLhs().getType();
298 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
299 return rewriter.notifyMatchFailure(
300 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
301 }
302
303 bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
304 emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
305
306 Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
307 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
308 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
309
310 rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
311 return success();
312 }
313};
314
315class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
316public:
317 using Base::Base;
318
319 LogicalResult
320 matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
321 ConversionPatternRewriter &rewriter) const override {
322
323 auto adaptedOp = adaptor.getOperand();
324 auto adaptedOpType = adaptedOp.getType();
325
326 if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
327 return rewriter.notifyMatchFailure(
328 op.getLoc(),
329 "negf currently only supports scalar types, not vectors or tensors");
330 }
331
332 if (!emitc::isSupportedFloatType(adaptedOpType)) {
333 return rewriter.notifyMatchFailure(
334 op.getLoc(), "floating-point type is not supported by EmitC");
335 }
336
337 rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType,
338 adaptedOp);
339 return success();
340 }
341};
342
343template <typename ArithOp, bool castToUnsigned>
344class CastConversion : public OpConversionPattern<ArithOp> {
345public:
346 using OpConversionPattern<ArithOp>::OpConversionPattern;
347
348 LogicalResult
349 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
350 ConversionPatternRewriter &rewriter) const override {
351
352 Type opReturnType = this->getTypeConverter()->convertType(op.getType());
353 if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
354 emitc::isPointerWideType(opReturnType)))
355 return rewriter.notifyMatchFailure(
356 op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
357
358 if (adaptor.getOperands().size() != 1) {
359 return rewriter.notifyMatchFailure(
360 op, "CastConversion only supports unary ops");
361 }
362
363 Type operandType = adaptor.getIn().getType();
364 if (!operandType || !(isa<IntegerType>(operandType) ||
365 emitc::isPointerWideType(operandType)))
366 return rewriter.notifyMatchFailure(
367 op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
368
369 // Signed (sign-extending) casts from i1 are not supported.
370 if (operandType.isInteger(1) && !castToUnsigned)
371 return rewriter.notifyMatchFailure(op,
372 "operation not supported on i1 type");
373
374 // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
375 // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
376 // truncation.
377 if (opReturnType.isInteger(1)) {
378 Type attrType = (emitc::isPointerWideType(operandType))
379 ? rewriter.getIndexType()
380 : operandType;
381 auto constOne = emitc::ConstantOp::create(
382 rewriter, op.getLoc(), operandType, rewriter.getOneAttr(attrType));
383 auto oneAndOperand = emitc::BitwiseAndOp::create(
384 rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne);
385 rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
386 oneAndOperand);
387 return success();
388 }
389
390 bool isTruncation =
391 (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
392 operandType.getIntOrFloatBitWidth() >
393 opReturnType.getIntOrFloatBitWidth());
394 bool doUnsigned = castToUnsigned || isTruncation;
395
396 // Adapt the signedness of the result (bitwidth-preserving cast)
397 // This is needed e.g., if the return type is signless.
398 Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
399
400 // Adapt the signedness of the operand (bitwidth-preserving cast)
401 Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
402 Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
403
404 // Actual cast (may change bitwidth)
405 auto cast =
406 emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp);
407
408 // Cast to the expected output type
409 auto result = adaptValueType(cast, rewriter, opReturnType);
410
411 rewriter.replaceOp(op, result);
412 return success();
413 }
414};
415
416template <typename ArithOp>
417class UnsignedCastConversion : public CastConversion<ArithOp, true> {
418 using CastConversion<ArithOp, true>::CastConversion;
419};
420
421template <typename ArithOp>
422class SignedCastConversion : public CastConversion<ArithOp, false> {
423 using CastConversion<ArithOp, false>::CastConversion;
424};
425
426template <typename ArithOp, typename EmitCOp>
427class ArithOpConversion final : public OpConversionPattern<ArithOp> {
428public:
429 using OpConversionPattern<ArithOp>::OpConversionPattern;
430
431 LogicalResult
432 matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
433 ConversionPatternRewriter &rewriter) const override {
434
435 Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
436 if (!newTy)
437 return rewriter.notifyMatchFailure(arithOp,
438 "converting result type failed");
439 rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
440 adaptor.getOperands());
441
442 return success();
443 }
444};
445
446template <class ArithOp, class EmitCOp>
447class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
448public:
449 using OpConversionPattern<ArithOp>::OpConversionPattern;
450
451 LogicalResult
452 matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor,
453 ConversionPatternRewriter &rewriter) const override {
454 Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
455 if (!newRetTy)
456 return rewriter.notifyMatchFailure(uiBinOp,
457 "converting result type failed");
458 if (!isa<IntegerType>(newRetTy)) {
459 return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
460 }
461 Type unsignedType =
462 adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true);
463 if (!unsignedType)
464 return rewriter.notifyMatchFailure(uiBinOp,
465 "converting result type failed");
466 Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
467 Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
468
469 auto newDivOp = EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType,
470 ArrayRef<Value>{lhsAdapted, rhsAdapted});
471 Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
472 rewriter.replaceOp(uiBinOp, resultAdapted);
473 return success();
474 }
475};
476
477template <typename ArithOp, typename EmitCOp>
478class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
479public:
480 using OpConversionPattern<ArithOp>::OpConversionPattern;
481
482 LogicalResult
483 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
484 ConversionPatternRewriter &rewriter) const override {
485
486 Type type = this->getTypeConverter()->convertType(op.getType());
487 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
488 return rewriter.notifyMatchFailure(
489 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
490 }
491
492 if (type.isInteger(1)) {
493 // arith expects wrap-around arithmethic, which doesn't happen on `bool`.
494 return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
495 }
496
497 Type arithmeticType = type;
498 if ((type.isSignlessInteger() || type.isSignedInteger()) &&
499 !bitEnumContainsAll(op.getOverflowFlags(),
500 arith::IntegerOverflowFlags::nsw)) {
501 // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
502 // we compute in unsigned integers to avoid UB.
503 arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
504 /*isSigned=*/false);
505 }
506
507 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
508 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
509
510 Value arithmeticResult =
511 EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
512
513 Value result = adaptValueType(arithmeticResult, rewriter, type);
514
515 rewriter.replaceOp(op, result);
516 return success();
517 }
518};
519
520template <typename ArithOp, typename EmitCOp>
521class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
522public:
523 using OpConversionPattern<ArithOp>::OpConversionPattern;
524
525 LogicalResult
526 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
527 ConversionPatternRewriter &rewriter) const override {
528
529 Type type = this->getTypeConverter()->convertType(op.getType());
530 if (!isa_and_nonnull<IntegerType>(type)) {
531 return rewriter.notifyMatchFailure(
532 op,
533 "expected integer type, vector/tensor support not yet implemented");
534 }
535
536 // Bitwise ops can be performed directly on booleans
537 if (type.isInteger(1)) {
538 rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
539 adaptor.getRhs());
540 return success();
541 }
542
543 // Bitwise ops are defined by the C standard on unsigned operands.
544 Type arithmeticType =
545 adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);
546
547 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
548 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
549
550 Value arithmeticResult =
551 EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
552
553 Value result = adaptValueType(arithmeticResult, rewriter, type);
554
555 rewriter.replaceOp(op, result);
556 return success();
557 }
558};
559
560template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
561class ShiftOpConversion : public OpConversionPattern<ArithOp> {
562public:
563 using OpConversionPattern<ArithOp>::OpConversionPattern;
564
565 LogicalResult
566 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
567 ConversionPatternRewriter &rewriter) const override {
568
569 Type type = this->getTypeConverter()->convertType(op.getType());
570 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
571 return rewriter.notifyMatchFailure(
572 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
573 }
574
575 if (type.isInteger(1)) {
576 return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
577 }
578
579 Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
580
581 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
582 // Shift amount interpreted as unsigned per Arith dialect spec.
583 Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
584 /*needsUnsigned=*/true);
585 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
586
587 // Add a runtime check for overflow
588 Value width;
589 if (emitc::isPointerWideType(type)) {
590 Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType,
591 rewriter.getIndexAttr(8));
592 emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create(
593 rewriter, op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
594 width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight,
595 sizeOfCall.getResult(0));
596 } else {
597 width = emitc::ConstantOp::create(
598 rewriter, op.getLoc(), rhsType,
599 rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
600 }
601
602 Value excessCheck =
603 emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
604 emitc::CmpPredicate::lt, rhs, width);
605
606 // Any concrete value is a valid refinement of poison.
607 Value poison = emitc::ConstantOp::create(
608 rewriter, op.getLoc(), arithmeticType,
609 (isa<IntegerType>(arithmeticType)
610 ? rewriter.getIntegerAttr(arithmeticType, 0)
611 : rewriter.getIndexAttr(0)));
612
613 emitc::ExpressionOp ternary =
614 emitc::ExpressionOp::create(rewriter, op.getLoc(), arithmeticType,
615 ValueRange({lhs, rhs, excessCheck, poison}),
616 /*do_not_inline=*/false);
617 Block &bodyBlock = ternary.createBody();
618 auto currentPoint = rewriter.getInsertionPoint();
619 rewriter.setInsertionPointToStart(&bodyBlock);
620 Value arithmeticResult =
621 EmitCOp::create(rewriter, op.getLoc(), arithmeticType,
622 bodyBlock.getArgument(0), bodyBlock.getArgument(1));
623 Value resultOrPoison = emitc::ConditionalOp::create(
624 rewriter, op.getLoc(), arithmeticType, bodyBlock.getArgument(2),
625 arithmeticResult, bodyBlock.getArgument(3));
626 emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
627 rewriter.setInsertionPoint(op->getBlock(), currentPoint);
628
629 Value result = adaptValueType(ternary, rewriter, type);
630
631 rewriter.replaceOp(op, result);
632 return success();
633 }
634};
635
636template <typename ArithOp, typename EmitCOp>
637class SignedShiftOpConversion final
638 : public ShiftOpConversion<ArithOp, EmitCOp, false> {
639 using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
640};
641
642template <typename ArithOp, typename EmitCOp>
643class UnsignedShiftOpConversion final
644 : public ShiftOpConversion<ArithOp, EmitCOp, true> {
645 using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
646};
647
648class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
649public:
650 using Base::Base;
651
652 LogicalResult
653 matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
654 ConversionPatternRewriter &rewriter) const override {
655
656 Type dstType = getTypeConverter()->convertType(selectOp.getType());
657 if (!dstType)
658 return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
659
660 if (!adaptor.getCondition().getType().isInteger(1))
661 return rewriter.notifyMatchFailure(
662 selectOp,
663 "can only be converted if condition is a scalar of type i1");
664
665 rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
666 adaptor.getOperands());
667
668 return success();
669 }
670};
671
672// Floating-point to integer conversions.
673template <typename CastOp>
674class FtoICastOpConversion : public OpConversionPattern<CastOp> {
675public:
676 FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
677 : OpConversionPattern<CastOp>(typeConverter, context) {}
678
679 LogicalResult
680 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
681 ConversionPatternRewriter &rewriter) const override {
682
683 Type operandType = adaptor.getIn().getType();
684 if (!emitc::isSupportedFloatType(operandType))
685 return rewriter.notifyMatchFailure(castOp,
686 "unsupported cast source type");
687
688 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
689 if (!dstType)
690 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
691
692 // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
693 // truncated to 0, whereas a boolean conversion would return true.
694 if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
695 return rewriter.notifyMatchFailure(castOp,
696 "unsupported cast destination type");
697
698 // Convert to unsigned if it's the "ui" variant
699 // Signless is interpreted as signed, so no need to cast for "si"
700 Type actualResultType = dstType;
701 if (isa<arith::FPToUIOp>(castOp)) {
702 actualResultType =
703 rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
704 /*isSigned=*/false);
705 }
706
707 Value result = emitc::CastOp::create(
708 rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands());
709
710 if (isa<arith::FPToUIOp>(castOp)) {
711 result =
712 emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result);
713 }
714 rewriter.replaceOp(castOp, result);
715
716 return success();
717 }
718};
719
720// Integer to floating-point conversions.
721template <typename CastOp>
722class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
723public:
724 ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
725 : OpConversionPattern<CastOp>(typeConverter, context) {}
726
727 LogicalResult
728 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
729 ConversionPatternRewriter &rewriter) const override {
730 // Vectors in particular are not supported
731 Type operandType = adaptor.getIn().getType();
732 if (!emitc::isSupportedIntegerType(operandType))
733 return rewriter.notifyMatchFailure(castOp,
734 "unsupported cast source type");
735
736 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
737 if (!dstType)
738 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
739
740 if (!emitc::isSupportedFloatType(dstType))
741 return rewriter.notifyMatchFailure(castOp,
742 "unsupported cast destination type");
743
744 // Convert to unsigned if it's the "ui" variant
745 // Signless is interpreted as signed, so no need to cast for "si"
746 Type actualOperandType = operandType;
747 if (isa<arith::UIToFPOp>(castOp)) {
748 actualOperandType =
749 rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
750 /*isSigned=*/false);
751 }
752 Value fpCastOperand = adaptor.getIn();
753 if (actualOperandType != operandType) {
754 fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(),
755 actualOperandType, fpCastOperand);
756 }
757 rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
758
759 return success();
760 }
761};
762
763// Floating-point to floating-point conversions.
764template <typename CastOp>
765class FpCastOpConversion : public OpConversionPattern<CastOp> {
766public:
767 FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
768 : OpConversionPattern<CastOp>(typeConverter, context) {}
769
770 LogicalResult
771 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
772 ConversionPatternRewriter &rewriter) const override {
773 // Vectors in particular are not supported.
774 Type operandType = adaptor.getIn().getType();
775 if (!emitc::isSupportedFloatType(operandType))
776 return rewriter.notifyMatchFailure(castOp,
777 "unsupported cast source type");
778 if (auto roundingModeOp =
779 dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
780 // Only supporting default rounding mode as of now.
781 if (roundingModeOp.getRoundingModeAttr())
782 return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
783 }
784
785 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
786 if (!dstType)
787 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
788
789 if (!emitc::isSupportedFloatType(dstType))
790 return rewriter.notifyMatchFailure(castOp,
791 "unsupported cast destination type");
792
793 Value fpCastOperand = adaptor.getIn();
794 rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
795
796 return success();
797 }
798};
799
800} // namespace
801
802//===----------------------------------------------------------------------===//
803// Pattern population
804//===----------------------------------------------------------------------===//
805
807 RewritePatternSet &patterns) {
808 MLIRContext *ctx = patterns.getContext();
809
811
812 // clang-format off
813 patterns.add<
814 ArithConstantOpConversionPattern,
815 ArithOpConversion<arith::AddFOp, emitc::AddOp>,
816 ArithOpConversion<arith::DivFOp, emitc::DivOp>,
817 ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
818 ArithOpConversion<arith::MulFOp, emitc::MulOp>,
819 ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
820 ArithOpConversion<arith::SubFOp, emitc::SubOp>,
821 BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
822 BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
823 IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
824 IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
825 IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
826 BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
827 BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
828 BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
829 UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
830 SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
831 UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
832 CmpFOpConversion,
833 CmpIOpConversion,
834 NegFOpConversion,
835 SelectOpConversion,
836 // Truncation is guaranteed for unsigned types.
837 UnsignedCastConversion<arith::TruncIOp>,
838 SignedCastConversion<arith::ExtSIOp>,
839 UnsignedCastConversion<arith::ExtUIOp>,
840 SignedCastConversion<arith::IndexCastOp>,
841 UnsignedCastConversion<arith::IndexCastUIOp>,
842 ItoFCastOpConversion<arith::SIToFPOp>,
843 ItoFCastOpConversion<arith::UIToFPOp>,
844 FtoICastOpConversion<arith::FPToSIOp>,
845 FtoICastOpConversion<arith::FPToUIOp>,
846 FpCastOpConversion<arith::ExtFOp>,
847 FpCastOpConversion<arith::TruncFOp>
848 >(typeConverter, ctx);
849 // clang-format on
850}
return success()
lhs
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition EmitC.cpp:117
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t/ssize_t/ptrdiff_t type.
Definition EmitC.cpp:132
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition EmitC.cpp:96
Include the generated interface declarations.
void registerConvertArithToEmitCInterface(DialectRegistry &registry)
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)