MLIR 23.0.0git
ArithToEmitC.cpp
Go to the documentation of this file.
1//===- ArithToEmitC.cpp - Arith to EmitC Patterns ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements patterns to convert the Arith dialect to the EmitC
10// dialect.
11//
12//===----------------------------------------------------------------------===//
13
15
23
24using namespace mlir;
25
26namespace {
27/// Implement the interface to convert Arith to EmitC.
28struct ArithToEmitCDialectInterface : public ConvertToEmitCPatternInterface {
29 ArithToEmitCDialectInterface(Dialect *dialect)
30 : ConvertToEmitCPatternInterface(dialect) {}
31
32 /// Hook for derived dialect interface to provide conversion patterns
33 /// and mark dialect legal for the conversion target.
34 void populateConvertToEmitCConversionPatterns(
35 ConversionTarget &target, TypeConverter &typeConverter,
36 RewritePatternSet &patterns) const final {
37 populateArithToEmitCPatterns(typeConverter, patterns);
38 }
39};
40} // namespace
41
43 registry.addExtension(+[](MLIRContext *ctx, arith::ArithDialect *dialect) {
44 dialect->addInterfaces<ArithToEmitCDialectInterface>();
45 });
46}
47
48//===----------------------------------------------------------------------===//
49// Conversion Patterns
50//===----------------------------------------------------------------------===//
51
52namespace {
53class ArithConstantOpConversionPattern
54 : public OpConversionPattern<arith::ConstantOp> {
55public:
56 using Base::Base;
57
58 LogicalResult
59 matchAndRewrite(arith::ConstantOp arithConst,
60 arith::ConstantOp::Adaptor adaptor,
61 ConversionPatternRewriter &rewriter) const override {
62 Type newTy = this->getTypeConverter()->convertType(arithConst.getType());
63 if (!newTy)
64 return rewriter.notifyMatchFailure(arithConst, "type conversion failed");
65 rewriter.replaceOpWithNewOp<emitc::ConstantOp>(arithConst, newTy,
66 adaptor.getValue());
67 return success();
68 }
69};
70
71/// Get the signed or unsigned type corresponding to \p ty.
72Type adaptIntegralTypeSignedness(Type ty, bool needsUnsigned) {
73 if (isa<IntegerType>(ty)) {
74 if (ty.isUnsignedInteger() != needsUnsigned) {
75 auto signedness = needsUnsigned
76 ? IntegerType::SignednessSemantics::Unsigned
77 : IntegerType::SignednessSemantics::Signed;
78 return IntegerType::get(ty.getContext(), ty.getIntOrFloatBitWidth(),
79 signedness);
80 }
81 } else if (emitc::isPointerWideType(ty)) {
82 if (isa<emitc::SizeTType>(ty) != needsUnsigned) {
83 if (needsUnsigned)
84 return emitc::SizeTType::get(ty.getContext());
85 return emitc::PtrDiffTType::get(ty.getContext());
86 }
87 }
88 return ty;
89}
90
91/// Insert a cast operation to type \p ty if \p val does not have this type.
92Value adaptValueType(Value val, ConversionPatternRewriter &rewriter, Type ty) {
93 return rewriter.createOrFold<emitc::CastOp>(val.getLoc(), ty, val);
94}
95
96class CmpFOpConversion : public OpConversionPattern<arith::CmpFOp> {
97public:
98 using Base::Base;
99
100 LogicalResult
101 matchAndRewrite(arith::CmpFOp op, OpAdaptor adaptor,
102 ConversionPatternRewriter &rewriter) const override {
103
104 if (!isa<FloatType>(adaptor.getRhs().getType())) {
105 return rewriter.notifyMatchFailure(op.getLoc(),
106 "cmpf currently only supported on "
107 "floats, not tensors/vectors thereof");
108 }
109
110 bool unordered = false;
111 emitc::CmpPredicate predicate;
112 switch (op.getPredicate()) {
113 case arith::CmpFPredicate::AlwaysFalse: {
114 auto constant =
115 emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
116 rewriter.getBoolAttr(/*value=*/false));
117 rewriter.replaceOp(op, constant);
118 return success();
119 }
120 case arith::CmpFPredicate::OEQ:
121 unordered = false;
122 predicate = emitc::CmpPredicate::eq;
123 break;
124 case arith::CmpFPredicate::OGT:
125 unordered = false;
126 predicate = emitc::CmpPredicate::gt;
127 break;
128 case arith::CmpFPredicate::OGE:
129 unordered = false;
130 predicate = emitc::CmpPredicate::ge;
131 break;
132 case arith::CmpFPredicate::OLT:
133 unordered = false;
134 predicate = emitc::CmpPredicate::lt;
135 break;
136 case arith::CmpFPredicate::OLE:
137 unordered = false;
138 predicate = emitc::CmpPredicate::le;
139 break;
140 case arith::CmpFPredicate::ONE:
141 unordered = false;
142 predicate = emitc::CmpPredicate::ne;
143 break;
144 case arith::CmpFPredicate::ORD: {
145 // ordered, i.e. none of the operands is NaN
146 auto cmp = createCheckIsOrdered(rewriter, op.getLoc(), adaptor.getLhs(),
147 adaptor.getRhs());
148 rewriter.replaceOp(op, cmp);
149 return success();
150 }
151 case arith::CmpFPredicate::UEQ:
152 unordered = true;
153 predicate = emitc::CmpPredicate::eq;
154 break;
155 case arith::CmpFPredicate::UGT:
156 unordered = true;
157 predicate = emitc::CmpPredicate::gt;
158 break;
159 case arith::CmpFPredicate::UGE:
160 unordered = true;
161 predicate = emitc::CmpPredicate::ge;
162 break;
163 case arith::CmpFPredicate::ULT:
164 unordered = true;
165 predicate = emitc::CmpPredicate::lt;
166 break;
167 case arith::CmpFPredicate::ULE:
168 unordered = true;
169 predicate = emitc::CmpPredicate::le;
170 break;
171 case arith::CmpFPredicate::UNE:
172 unordered = true;
173 predicate = emitc::CmpPredicate::ne;
174 break;
175 case arith::CmpFPredicate::UNO: {
176 // unordered, i.e. either operand is nan
177 auto cmp = createCheckIsUnordered(rewriter, op.getLoc(), adaptor.getLhs(),
178 adaptor.getRhs());
179 rewriter.replaceOp(op, cmp);
180 return success();
181 }
182 case arith::CmpFPredicate::AlwaysTrue: {
183 auto constant =
184 emitc::ConstantOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
185 rewriter.getBoolAttr(/*value=*/true));
186 rewriter.replaceOp(op, constant);
187 return success();
188 }
189 }
190
191 // Compare the values naively
192 auto cmpResult =
193 emitc::CmpOp::create(rewriter, op.getLoc(), op.getType(), predicate,
194 adaptor.getLhs(), adaptor.getRhs());
195
196 // Adjust the results for unordered/ordered semantics
197 if (unordered) {
198 auto isUnordered = createCheckIsUnordered(
199 rewriter, op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
200 rewriter.replaceOpWithNewOp<emitc::LogicalOrOp>(op, op.getType(),
201 isUnordered, cmpResult);
202 return success();
203 }
204
205 auto isOrdered = createCheckIsOrdered(rewriter, op.getLoc(),
206 adaptor.getLhs(), adaptor.getRhs());
207 rewriter.replaceOpWithNewOp<emitc::LogicalAndOp>(op, op.getType(),
208 isOrdered, cmpResult);
209 return success();
210 }
211
212private:
213 /// Return a value that is true if \p operand is NaN.
214 Value isNaN(ConversionPatternRewriter &rewriter, Location loc,
215 Value operand) const {
216 // A value is NaN exactly when it compares unequal to itself.
217 return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
218 emitc::CmpPredicate::ne, operand, operand);
219 }
220
221 /// Return a value that is true if \p operand is not NaN.
222 Value isNotNaN(ConversionPatternRewriter &rewriter, Location loc,
223 Value operand) const {
224 // A value is not NaN exactly when it compares equal to itself.
225 return emitc::CmpOp::create(rewriter, loc, rewriter.getI1Type(),
226 emitc::CmpPredicate::eq, operand, operand);
227 }
228
229 /// Return a value that is true if the operands \p first and \p second are
230 /// unordered (i.e., at least one of them is NaN).
231 Value createCheckIsUnordered(ConversionPatternRewriter &rewriter,
232 Location loc, Value first, Value second) const {
233 auto firstIsNaN = isNaN(rewriter, loc, first);
234 auto secondIsNaN = isNaN(rewriter, loc, second);
235 return emitc::LogicalOrOp::create(rewriter, loc, rewriter.getI1Type(),
236 firstIsNaN, secondIsNaN);
237 }
238
239 /// Return a value that is true if the operands \p first and \p second are
240 /// both ordered (i.e., none one of them is NaN).
241 Value createCheckIsOrdered(ConversionPatternRewriter &rewriter, Location loc,
242 Value first, Value second) const {
243 auto firstIsNotNaN = isNotNaN(rewriter, loc, first);
244 auto secondIsNotNaN = isNotNaN(rewriter, loc, second);
245 return emitc::LogicalAndOp::create(rewriter, loc, rewriter.getI1Type(),
246 firstIsNotNaN, secondIsNotNaN);
247 }
248};
249
250class CmpIOpConversion : public OpConversionPattern<arith::CmpIOp> {
251public:
252 using Base::Base;
253
254 bool needsUnsignedCmp(arith::CmpIPredicate pred) const {
255 switch (pred) {
256 case arith::CmpIPredicate::eq:
257 case arith::CmpIPredicate::ne:
258 case arith::CmpIPredicate::slt:
259 case arith::CmpIPredicate::sle:
260 case arith::CmpIPredicate::sgt:
261 case arith::CmpIPredicate::sge:
262 return false;
263 case arith::CmpIPredicate::ult:
264 case arith::CmpIPredicate::ule:
265 case arith::CmpIPredicate::ugt:
266 case arith::CmpIPredicate::uge:
267 return true;
268 }
269 llvm_unreachable("unknown cmpi predicate kind");
270 }
271
272 emitc::CmpPredicate toEmitCPred(arith::CmpIPredicate pred) const {
273 switch (pred) {
274 case arith::CmpIPredicate::eq:
275 return emitc::CmpPredicate::eq;
276 case arith::CmpIPredicate::ne:
277 return emitc::CmpPredicate::ne;
278 case arith::CmpIPredicate::slt:
279 case arith::CmpIPredicate::ult:
280 return emitc::CmpPredicate::lt;
281 case arith::CmpIPredicate::sle:
282 case arith::CmpIPredicate::ule:
283 return emitc::CmpPredicate::le;
284 case arith::CmpIPredicate::sgt:
285 case arith::CmpIPredicate::ugt:
286 return emitc::CmpPredicate::gt;
287 case arith::CmpIPredicate::sge:
288 case arith::CmpIPredicate::uge:
289 return emitc::CmpPredicate::ge;
290 }
291 llvm_unreachable("unknown cmpi predicate kind");
292 }
293
294 LogicalResult
295 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
296 ConversionPatternRewriter &rewriter) const override {
297
298 Type type = adaptor.getLhs().getType();
299 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
300 return rewriter.notifyMatchFailure(
301 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
302 }
303
304 bool needsUnsigned = needsUnsignedCmp(op.getPredicate());
305 emitc::CmpPredicate pred = toEmitCPred(op.getPredicate());
306
307 Type arithmeticType = adaptIntegralTypeSignedness(type, needsUnsigned);
308 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
309 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
310
311 rewriter.replaceOpWithNewOp<emitc::CmpOp>(op, op.getType(), pred, lhs, rhs);
312 return success();
313 }
314};
315
316class NegFOpConversion : public OpConversionPattern<arith::NegFOp> {
317public:
318 using Base::Base;
319
320 LogicalResult
321 matchAndRewrite(arith::NegFOp op, OpAdaptor adaptor,
322 ConversionPatternRewriter &rewriter) const override {
323
324 auto adaptedOp = adaptor.getOperand();
325 auto adaptedOpType = adaptedOp.getType();
326
327 if (isa<TensorType>(adaptedOpType) || isa<VectorType>(adaptedOpType)) {
328 return rewriter.notifyMatchFailure(
329 op.getLoc(),
330 "negf currently only supports scalar types, not vectors or tensors");
331 }
332
333 if (!emitc::isSupportedFloatType(adaptedOpType)) {
334 return rewriter.notifyMatchFailure(
335 op.getLoc(), "floating-point type is not supported by EmitC");
336 }
337
338 rewriter.replaceOpWithNewOp<emitc::UnaryMinusOp>(op, adaptedOpType,
339 adaptedOp);
340 return success();
341 }
342};
343
344template <typename ArithOp, bool castToUnsigned>
345class CastConversion : public OpConversionPattern<ArithOp> {
346public:
347 using OpConversionPattern<ArithOp>::OpConversionPattern;
348
349 LogicalResult
350 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
351 ConversionPatternRewriter &rewriter) const override {
352
353 Type opReturnType = this->getTypeConverter()->convertType(op.getType());
354 if (!opReturnType || !(isa<IntegerType>(opReturnType) ||
355 emitc::isPointerWideType(opReturnType)))
356 return rewriter.notifyMatchFailure(
357 op, "expected integer or size_t/ssize_t/ptrdiff_t result type");
358
359 if (adaptor.getOperands().size() != 1) {
360 return rewriter.notifyMatchFailure(
361 op, "CastConversion only supports unary ops");
362 }
363
364 Type operandType = adaptor.getIn().getType();
365 if (!operandType || !(isa<IntegerType>(operandType) ||
366 emitc::isPointerWideType(operandType)))
367 return rewriter.notifyMatchFailure(
368 op, "expected integer or size_t/ssize_t/ptrdiff_t operand type");
369
370 // Signed (sign-extending) casts from i1 are not supported.
371 if (operandType.isInteger(1) && !castToUnsigned)
372 return rewriter.notifyMatchFailure(op,
373 "operation not supported on i1 type");
374
375 // to-i1 conversions: arith semantics want truncation, whereas (bool)(v) is
376 // equivalent to (v != 0). Implementing as (bool)(v & 0x01) gives
377 // truncation.
378 if (opReturnType.isInteger(1)) {
379 Type attrType = (emitc::isPointerWideType(operandType))
380 ? rewriter.getIndexType()
381 : operandType;
382 auto constOne = emitc::ConstantOp::create(
383 rewriter, op.getLoc(), operandType, rewriter.getOneAttr(attrType));
384 auto oneAndOperand = emitc::BitwiseAndOp::create(
385 rewriter, op.getLoc(), operandType, adaptor.getIn(), constOne);
386 rewriter.replaceOpWithNewOp<emitc::CastOp>(op, opReturnType,
387 oneAndOperand);
388 return success();
389 }
390
391 bool isTruncation =
392 (isa<IntegerType>(operandType) && isa<IntegerType>(opReturnType) &&
393 operandType.getIntOrFloatBitWidth() >
394 opReturnType.getIntOrFloatBitWidth());
395 bool doUnsigned = castToUnsigned || isTruncation;
396
397 // Adapt the signedness of the result (bitwidth-preserving cast)
398 // This is needed e.g., if the return type is signless.
399 Type castDestType = adaptIntegralTypeSignedness(opReturnType, doUnsigned);
400
401 // Adapt the signedness of the operand (bitwidth-preserving cast)
402 Type castSrcType = adaptIntegralTypeSignedness(operandType, doUnsigned);
403 Value actualOp = adaptValueType(adaptor.getIn(), rewriter, castSrcType);
404
405 // Actual cast (may change bitwidth)
406 auto cast =
407 emitc::CastOp::create(rewriter, op.getLoc(), castDestType, actualOp);
408
409 // Cast to the expected output type
410 auto result = adaptValueType(cast, rewriter, opReturnType);
411
412 rewriter.replaceOp(op, result);
413 return success();
414 }
415};
416
417template <typename ArithOp>
418class UnsignedCastConversion : public CastConversion<ArithOp, true> {
419 using CastConversion<ArithOp, true>::CastConversion;
420};
421
422template <typename ArithOp>
423class SignedCastConversion : public CastConversion<ArithOp, false> {
424 using CastConversion<ArithOp, false>::CastConversion;
425};
426
427template <typename ArithOp, typename EmitCOp>
428class ArithOpConversion final : public OpConversionPattern<ArithOp> {
429public:
430 using OpConversionPattern<ArithOp>::OpConversionPattern;
431
432 LogicalResult
433 matchAndRewrite(ArithOp arithOp, typename ArithOp::Adaptor adaptor,
434 ConversionPatternRewriter &rewriter) const override {
435
436 Type newTy = this->getTypeConverter()->convertType(arithOp.getType());
437 if (!newTy)
438 return rewriter.notifyMatchFailure(arithOp,
439 "converting result type failed");
440 rewriter.template replaceOpWithNewOp<EmitCOp>(arithOp, newTy,
441 adaptor.getOperands());
442
443 return success();
444 }
445};
446
447template <class ArithOp, class EmitCOp>
448class BinaryUIOpConversion final : public OpConversionPattern<ArithOp> {
449public:
450 using OpConversionPattern<ArithOp>::OpConversionPattern;
451
452 LogicalResult
453 matchAndRewrite(ArithOp uiBinOp, typename ArithOp::Adaptor adaptor,
454 ConversionPatternRewriter &rewriter) const override {
455 Type newRetTy = this->getTypeConverter()->convertType(uiBinOp.getType());
456 if (!newRetTy)
457 return rewriter.notifyMatchFailure(uiBinOp,
458 "converting result type failed");
459 if (!isa<IntegerType>(newRetTy)) {
460 return rewriter.notifyMatchFailure(uiBinOp, "expected integer type");
461 }
462 Type unsignedType =
463 adaptIntegralTypeSignedness(newRetTy, /*needsUnsigned=*/true);
464 if (!unsignedType)
465 return rewriter.notifyMatchFailure(uiBinOp,
466 "converting result type failed");
467 Value lhsAdapted = adaptValueType(uiBinOp.getLhs(), rewriter, unsignedType);
468 Value rhsAdapted = adaptValueType(uiBinOp.getRhs(), rewriter, unsignedType);
469
470 auto newDivOp = EmitCOp::create(rewriter, uiBinOp.getLoc(), unsignedType,
471 ArrayRef<Value>{lhsAdapted, rhsAdapted});
472 Value resultAdapted = adaptValueType(newDivOp, rewriter, newRetTy);
473 rewriter.replaceOp(uiBinOp, resultAdapted);
474 return success();
475 }
476};
477
478template <typename ArithOp, typename EmitCOp>
479class IntegerOpConversion final : public OpConversionPattern<ArithOp> {
480public:
481 using OpConversionPattern<ArithOp>::OpConversionPattern;
482
483 LogicalResult
484 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
485 ConversionPatternRewriter &rewriter) const override {
486
487 Type type = this->getTypeConverter()->convertType(op.getType());
488 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
489 return rewriter.notifyMatchFailure(
490 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
491 }
492
493 if (type.isInteger(1)) {
494 // arith expects wrap-around arithmethic, which doesn't happen on `bool`.
495 return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
496 }
497
498 Type arithmeticType = type;
499 if ((type.isSignlessInteger() || type.isSignedInteger()) &&
500 !bitEnumContainsAll(op.getOverflowFlags(),
501 arith::IntegerOverflowFlags::nsw)) {
502 // If the C type is signed and the op doesn't guarantee "No Signed Wrap",
503 // we compute in unsigned integers to avoid UB.
504 arithmeticType = rewriter.getIntegerType(type.getIntOrFloatBitWidth(),
505 /*isSigned=*/false);
506 }
507
508 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
509 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
510
511 Value arithmeticResult =
512 EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
513
514 Value result = adaptValueType(arithmeticResult, rewriter, type);
515
516 rewriter.replaceOp(op, result);
517 return success();
518 }
519};
520
521template <typename ArithOp, typename EmitCOp>
522class BitwiseOpConversion : public OpConversionPattern<ArithOp> {
523public:
524 using OpConversionPattern<ArithOp>::OpConversionPattern;
525
526 LogicalResult
527 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
528 ConversionPatternRewriter &rewriter) const override {
529
530 Type type = this->getTypeConverter()->convertType(op.getType());
531 if (!isa_and_nonnull<IntegerType>(type)) {
532 return rewriter.notifyMatchFailure(
533 op,
534 "expected integer type, vector/tensor support not yet implemented");
535 }
536
537 // Bitwise ops can be performed directly on booleans
538 if (type.isInteger(1)) {
539 rewriter.replaceOpWithNewOp<EmitCOp>(op, type, adaptor.getLhs(),
540 adaptor.getRhs());
541 return success();
542 }
543
544 // Bitwise ops are defined by the C standard on unsigned operands.
545 Type arithmeticType =
546 adaptIntegralTypeSignedness(type, /*needsUnsigned=*/true);
547
548 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
549 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, arithmeticType);
550
551 Value arithmeticResult =
552 EmitCOp::create(rewriter, op.getLoc(), arithmeticType, lhs, rhs);
553
554 Value result = adaptValueType(arithmeticResult, rewriter, type);
555
556 rewriter.replaceOp(op, result);
557 return success();
558 }
559};
560
561template <typename ArithOp, typename EmitCOp, bool isUnsignedOp>
562class ShiftOpConversion : public OpConversionPattern<ArithOp> {
563public:
564 using OpConversionPattern<ArithOp>::OpConversionPattern;
565
566 LogicalResult
567 matchAndRewrite(ArithOp op, typename ArithOp::Adaptor adaptor,
568 ConversionPatternRewriter &rewriter) const override {
569
570 Type type = this->getTypeConverter()->convertType(op.getType());
571 if (!type || !(isa<IntegerType>(type) || emitc::isPointerWideType(type))) {
572 return rewriter.notifyMatchFailure(
573 op, "expected integer or size_t/ssize_t/ptrdiff_t type");
574 }
575
576 if (type.isInteger(1)) {
577 return rewriter.notifyMatchFailure(op, "i1 type is not implemented");
578 }
579
580 Type arithmeticType = adaptIntegralTypeSignedness(type, isUnsignedOp);
581
582 Value lhs = adaptValueType(adaptor.getLhs(), rewriter, arithmeticType);
583 // Shift amount interpreted as unsigned per Arith dialect spec.
584 Type rhsType = adaptIntegralTypeSignedness(adaptor.getRhs().getType(),
585 /*needsUnsigned=*/true);
586 Value rhs = adaptValueType(adaptor.getRhs(), rewriter, rhsType);
587
588 // Add a runtime check for overflow
589 Value width;
590 if (emitc::isPointerWideType(type)) {
591 Value eight = emitc::ConstantOp::create(rewriter, op.getLoc(), rhsType,
592 rewriter.getIndexAttr(8));
593 emitc::CallOpaqueOp sizeOfCall = emitc::CallOpaqueOp::create(
594 rewriter, op.getLoc(), rhsType, "sizeof", ArrayRef<Value>{eight});
595 width = emitc::MulOp::create(rewriter, op.getLoc(), rhsType, eight,
596 sizeOfCall.getResult(0));
597 } else {
598 width = emitc::ConstantOp::create(
599 rewriter, op.getLoc(), rhsType,
600 rewriter.getIntegerAttr(rhsType, type.getIntOrFloatBitWidth()));
601 }
602
603 Value excessCheck =
604 emitc::CmpOp::create(rewriter, op.getLoc(), rewriter.getI1Type(),
605 emitc::CmpPredicate::lt, rhs, width);
606
607 // Any concrete value is a valid refinement of poison.
608 Value poison = emitc::ConstantOp::create(
609 rewriter, op.getLoc(), arithmeticType,
610 (isa<IntegerType>(arithmeticType)
611 ? rewriter.getIntegerAttr(arithmeticType, 0)
612 : rewriter.getIndexAttr(0)));
613
614 emitc::ExpressionOp ternary =
615 emitc::ExpressionOp::create(rewriter, op.getLoc(), arithmeticType,
616 ValueRange({lhs, rhs, excessCheck, poison}),
617 /*do_not_inline=*/false);
618 Block &bodyBlock = ternary.createBody();
619 auto currentPoint = rewriter.getInsertionPoint();
620 rewriter.setInsertionPointToStart(&bodyBlock);
621 Value arithmeticResult =
622 EmitCOp::create(rewriter, op.getLoc(), arithmeticType,
623 bodyBlock.getArgument(0), bodyBlock.getArgument(1));
624 Value resultOrPoison = emitc::ConditionalOp::create(
625 rewriter, op.getLoc(), arithmeticType, bodyBlock.getArgument(2),
626 arithmeticResult, bodyBlock.getArgument(3));
627 emitc::YieldOp::create(rewriter, op.getLoc(), resultOrPoison);
628 rewriter.setInsertionPoint(op->getBlock(), currentPoint);
629
630 Value result = adaptValueType(ternary, rewriter, type);
631
632 rewriter.replaceOp(op, result);
633 return success();
634 }
635};
636
637template <typename ArithOp, typename EmitCOp>
638class SignedShiftOpConversion final
639 : public ShiftOpConversion<ArithOp, EmitCOp, false> {
640 using ShiftOpConversion<ArithOp, EmitCOp, false>::ShiftOpConversion;
641};
642
643template <typename ArithOp, typename EmitCOp>
644class UnsignedShiftOpConversion final
645 : public ShiftOpConversion<ArithOp, EmitCOp, true> {
646 using ShiftOpConversion<ArithOp, EmitCOp, true>::ShiftOpConversion;
647};
648
649class SelectOpConversion : public OpConversionPattern<arith::SelectOp> {
650public:
651 using Base::Base;
652
653 LogicalResult
654 matchAndRewrite(arith::SelectOp selectOp, OpAdaptor adaptor,
655 ConversionPatternRewriter &rewriter) const override {
656
657 Type dstType = getTypeConverter()->convertType(selectOp.getType());
658 if (!dstType)
659 return rewriter.notifyMatchFailure(selectOp, "type conversion failed");
660
661 if (!adaptor.getCondition().getType().isInteger(1))
662 return rewriter.notifyMatchFailure(
663 selectOp,
664 "can only be converted if condition is a scalar of type i1");
665
666 rewriter.replaceOpWithNewOp<emitc::ConditionalOp>(selectOp, dstType,
667 adaptor.getOperands());
668
669 return success();
670 }
671};
672
673// Floating-point to integer conversions.
674template <typename CastOp>
675class FtoICastOpConversion : public OpConversionPattern<CastOp> {
676public:
677 FtoICastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
678 : OpConversionPattern<CastOp>(typeConverter, context) {}
679
680 LogicalResult
681 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
682 ConversionPatternRewriter &rewriter) const override {
683
684 Type operandType = adaptor.getIn().getType();
685 if (!emitc::isSupportedFloatType(operandType))
686 return rewriter.notifyMatchFailure(castOp,
687 "unsupported cast source type");
688
689 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
690 if (!dstType)
691 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
692
693 // Float-to-i1 casts are not supported: any value with 0 < value < 1 must be
694 // truncated to 0, whereas a boolean conversion would return true.
695 if (!emitc::isSupportedIntegerType(dstType) || dstType.isInteger(1))
696 return rewriter.notifyMatchFailure(castOp,
697 "unsupported cast destination type");
698
699 // Convert to unsigned if it's the "ui" variant
700 // Signless is interpreted as signed, so no need to cast for "si"
701 Type actualResultType = dstType;
702 if (isa<arith::FPToUIOp>(castOp)) {
703 actualResultType =
704 rewriter.getIntegerType(dstType.getIntOrFloatBitWidth(),
705 /*isSigned=*/false);
706 }
707
708 Value result = emitc::CastOp::create(
709 rewriter, castOp.getLoc(), actualResultType, adaptor.getOperands());
710
711 if (isa<arith::FPToUIOp>(castOp)) {
712 result =
713 emitc::CastOp::create(rewriter, castOp.getLoc(), dstType, result);
714 }
715 rewriter.replaceOp(castOp, result);
716
717 return success();
718 }
719};
720
721// Integer to floating-point conversions.
722template <typename CastOp>
723class ItoFCastOpConversion : public OpConversionPattern<CastOp> {
724public:
725 ItoFCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
726 : OpConversionPattern<CastOp>(typeConverter, context) {}
727
728 LogicalResult
729 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
730 ConversionPatternRewriter &rewriter) const override {
731 // Vectors in particular are not supported
732 Type operandType = adaptor.getIn().getType();
733 if (!emitc::isSupportedIntegerType(operandType))
734 return rewriter.notifyMatchFailure(castOp,
735 "unsupported cast source type");
736
737 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
738 if (!dstType)
739 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
740
741 if (!emitc::isSupportedFloatType(dstType))
742 return rewriter.notifyMatchFailure(castOp,
743 "unsupported cast destination type");
744
745 // Convert to unsigned if it's the "ui" variant
746 // Signless is interpreted as signed, so no need to cast for "si"
747 Type actualOperandType = operandType;
748 if (isa<arith::UIToFPOp>(castOp)) {
749 actualOperandType =
750 rewriter.getIntegerType(operandType.getIntOrFloatBitWidth(),
751 /*isSigned=*/false);
752 }
753 Value fpCastOperand = adaptor.getIn();
754 if (actualOperandType != operandType) {
755 fpCastOperand = emitc::CastOp::create(rewriter, castOp.getLoc(),
756 actualOperandType, fpCastOperand);
757 }
758 rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
759
760 return success();
761 }
762};
763
764// Floating-point to floating-point conversions.
765template <typename CastOp>
766class FpCastOpConversion : public OpConversionPattern<CastOp> {
767public:
768 FpCastOpConversion(const TypeConverter &typeConverter, MLIRContext *context)
769 : OpConversionPattern<CastOp>(typeConverter, context) {}
770
771 LogicalResult
772 matchAndRewrite(CastOp castOp, typename CastOp::Adaptor adaptor,
773 ConversionPatternRewriter &rewriter) const override {
774 // Vectors in particular are not supported.
775 Type operandType = adaptor.getIn().getType();
776 if (!emitc::isSupportedFloatType(operandType))
777 return rewriter.notifyMatchFailure(castOp,
778 "unsupported cast source type");
779 if (auto roundingModeOp =
780 dyn_cast<arith::ArithRoundingModeInterface>(*castOp)) {
781 // Only supporting default rounding mode as of now.
782 if (roundingModeOp.getRoundingModeAttr())
783 return rewriter.notifyMatchFailure(castOp, "unsupported rounding mode");
784 }
785
786 Type dstType = this->getTypeConverter()->convertType(castOp.getType());
787 if (!dstType)
788 return rewriter.notifyMatchFailure(castOp, "type conversion failed");
789
790 if (!emitc::isSupportedFloatType(dstType))
791 return rewriter.notifyMatchFailure(castOp,
792 "unsupported cast destination type");
793
794 Value fpCastOperand = adaptor.getIn();
795 rewriter.replaceOpWithNewOp<emitc::CastOp>(castOp, dstType, fpCastOperand);
796
797 return success();
798 }
799};
800
801} // namespace
802
803//===----------------------------------------------------------------------===//
804// Pattern population
805//===----------------------------------------------------------------------===//
806
808 RewritePatternSet &patterns) {
809 MLIRContext *ctx = patterns.getContext();
810
812
813 // clang-format off
814 patterns.add<
815 ArithConstantOpConversionPattern,
816 ArithOpConversion<arith::AddFOp, emitc::AddOp>,
817 ArithOpConversion<arith::DivFOp, emitc::DivOp>,
818 ArithOpConversion<arith::DivSIOp, emitc::DivOp>,
819 ArithOpConversion<arith::MulFOp, emitc::MulOp>,
820 ArithOpConversion<arith::RemSIOp, emitc::RemOp>,
821 ArithOpConversion<arith::SubFOp, emitc::SubOp>,
822 BinaryUIOpConversion<arith::DivUIOp, emitc::DivOp>,
823 BinaryUIOpConversion<arith::RemUIOp, emitc::RemOp>,
824 IntegerOpConversion<arith::AddIOp, emitc::AddOp>,
825 IntegerOpConversion<arith::MulIOp, emitc::MulOp>,
826 IntegerOpConversion<arith::SubIOp, emitc::SubOp>,
827 BitwiseOpConversion<arith::AndIOp, emitc::BitwiseAndOp>,
828 BitwiseOpConversion<arith::OrIOp, emitc::BitwiseOrOp>,
829 BitwiseOpConversion<arith::XOrIOp, emitc::BitwiseXorOp>,
830 UnsignedShiftOpConversion<arith::ShLIOp, emitc::BitwiseLeftShiftOp>,
831 SignedShiftOpConversion<arith::ShRSIOp, emitc::BitwiseRightShiftOp>,
832 UnsignedShiftOpConversion<arith::ShRUIOp, emitc::BitwiseRightShiftOp>,
833 CmpFOpConversion,
834 CmpIOpConversion,
835 NegFOpConversion,
836 SelectOpConversion,
837 // Truncation is guaranteed for unsigned types.
838 UnsignedCastConversion<arith::TruncIOp>,
839 SignedCastConversion<arith::ExtSIOp>,
840 UnsignedCastConversion<arith::ExtUIOp>,
841 SignedCastConversion<arith::IndexCastOp>,
842 UnsignedCastConversion<arith::IndexCastUIOp>,
843 ItoFCastOpConversion<arith::SIToFPOp>,
844 ItoFCastOpConversion<arith::UIToFPOp>,
845 FtoICastOpConversion<arith::FPToSIOp>,
846 FtoICastOpConversion<arith::FPToUIOp>,
847 FpCastOpConversion<arith::ExtFOp>,
848 FpCastOpConversion<arith::TruncFOp>
849 >(typeConverter, ctx);
850 // clang-format on
851}
return success()
lhs
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
Definition Types.cpp:78
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
Definition Types.cpp:66
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
Definition Types.cpp:90
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition EmitC.cpp:117
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t/ssize_t type.
Definition EmitC.cpp:132
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition EmitC.cpp:96
Include the generated interface declarations.
void registerConvertArithToEmitCInterface(DialectRegistry &registry)
void populateEmitCSizeTTypeConversions(TypeConverter &converter)
void populateArithToEmitCPatterns(TypeConverter &typeConverter, RewritePatternSet &patterns)