MLIR 22.0.0git
EmitC.cpp
Go to the documentation of this file.
1//===- EmitC.cpp - EmitC Dialect ------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/Builders.h"
15#include "mlir/IR/Types.h"
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/TypeSwitch.h"
21#include "llvm/Support/Casting.h"
22
23using namespace mlir;
24using namespace mlir::emitc;
25
26#include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc"
27
28//===----------------------------------------------------------------------===//
29// EmitCDialect
30//===----------------------------------------------------------------------===//
31
/// Registers every operation, type and attribute of the EmitC dialect.
/// Each list is expanded from the .inc file included inside the template
/// argument list, selected by the GET_OP_LIST / GET_TYPEDEF_LIST /
/// GET_ATTRDEF_LIST macro defined just before it.
32void EmitCDialect::initialize() {
33 addOperations<
34#define GET_OP_LIST
35#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
36 >();
37 addTypes<
38#define GET_TYPEDEF_LIST
39#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
40 >();
41 addAttributes<
42#define GET_ATTRDEF_LIST
43#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
44 >();
45}
46
47/// Materialize a single constant operation from a given attribute value with
48/// the desired resultant type.
49Operation *EmitCDialect::materializeConstant(OpBuilder &builder,
50 Attribute value, Type type,
51 Location loc) {
52 return emitc::ConstantOp::create(builder, loc, type, value);
53}
54
/// NOTE(review): this listing omits the signature line of this function; only
/// its body is visible. It creates an emitc.yield with no operands at `loc`
/// using `builder`.
55/// Default callback for builders of ops carrying a region. Inserts a yield
56/// without arguments.
58 emitc::YieldOp::create(builder, loc);
59}
60
/// Returns true if `type` is usable in EmitC (the signature line is missing
/// from this listing). Accepted:
///  - emitc opaque types;
///  - emitc pointers whose pointee is itself supported;
///  - emitc arrays whose element is a supported, non-array type;
///  - index and pointer-wide types;
///  - integer and float types accepted by the width predicates;
///  - statically shaped tensors with a supported, non-array element type;
///  - tuples whose members are all supported, non-array types.
/// Everything else yields false.
62 if (llvm::isa<emitc::OpaqueType>(type))
63 return true;
64 if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
65 return isSupportedEmitCType(ptrType.getPointee());
// Arrays of arrays are rejected; only a single array level is accepted.
66 if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
67 auto elemType = arrayType.getElementType();
68 return !llvm::isa<emitc::ArrayType>(elemType) &&
69 isSupportedEmitCType(elemType);
70 }
71 if (type.isIndex() || emitc::isPointerWideType(type))
72 return true;
73 if (llvm::isa<IntegerType>(type))
74 return isSupportedIntegerType(type);
75 if (llvm::isa<FloatType>(type))
76 return isSupportedFloatType(type);
// Tensors must be statically shaped and may not hold emitc arrays.
77 if (auto tensorType = llvm::dyn_cast<TensorType>(type)) {
78 if (!tensorType.hasStaticShape()) {
79 return false;
80 }
81 auto elemType = tensorType.getElementType();
82 if (llvm::isa<emitc::ArrayType>(elemType)) {
83 return false;
84 }
85 return isSupportedEmitCType(elemType);
86 }
// Tuple members are checked individually; array members are rejected.
87 if (auto tupleType = llvm::dyn_cast<TupleType>(type)) {
88 return llvm::all_of(tupleType.getTypes(), [](Type type) {
89 return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
90 });
91 }
92 return false;
93}
94
/// Integer-width predicate (the signature line is missing from this listing;
/// L73-74 call it as isSupportedIntegerType). Returns true only for an
/// IntegerType of width 1, 8, 16, 32 or 64. Any other width, and any
/// non-integer type, yields false.
96 if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
97 switch (intType.getWidth()) {
98 case 1:
99 case 8:
100 case 16:
101 case 32:
102 case 64:
103 return true;
104 default:
105 return false;
106 }
107 }
108 return false;
109}
110
/// Accepts index and emitc opaque types. The rest of this disjunction
/// (source line 113) and the signature are not visible in this listing.
/// NOTE(review): presumably it also accepts supported integer types; confirm.
112 return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
114}
115
/// Float-width predicate (the signature line is missing from this listing;
/// L75-76 call it as isSupportedFloatType). 16-bit floats are accepted only
/// when they are f16 or bf16. Any 32- or 64-bit FloatType is accepted. All
/// other widths and non-float types yield false.
117 if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
118 switch (floatType.getWidth()) {
119 case 16:
120 return llvm::isa<Float16Type, BFloat16Type>(type);
121 case 32:
122 case 64:
123 return true;
124 default:
125 return false;
126 }
127 }
128 return false;
129}
130
/// Returns true for emitc::SignedSizeTType, emitc::SizeTType and
/// emitc::PtrDiffTType (the signature line is missing from this listing).
132 return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
133 type);
134}
135
/// Accepts index types, pointer-wide types and emitc pointer types. A further
/// part of the disjunction (source line 138) and the signature are not
/// visible in this listing.
/// NOTE(review): presumably the hidden line covers integer and float types.
137 return llvm::isa<IndexType>(type) || isPointerWideType(type) ||
139 isa<emitc::PointerType>(type);
140}
141
142/// Check that the type of the initial value is compatible with the operation's
143/// result type.
/// Rules, in order:
///  - #emitc.opaque values are always accepted;
///  - StringAttr values are rejected with a hint to use #emitc.opaque;
///  - an lvalue result type is unwrapped to its value type before comparing;
///  - an index-typed value is accepted for a pointer-wide result type;
///  - otherwise the value's type must equal the result type exactly.
/// NOTE(review): `cast<TypedAttr>` asserts if `value` is neither opaque,
/// string, nor typed. Presumably the op definitions restrict `value`
/// accordingly; confirm.
145 Attribute value) {
146 assert(op->getNumResults() == 1 && "operation must have 1 result");
147
148 if (llvm::isa<emitc::OpaqueAttr>(value))
149 return success();
150
151 if (llvm::isa<StringAttr>(value))
152 return op->emitOpError()
153 << "string attributes are not supported, use #emitc.opaque instead";
154
155 Type resultType = op->getResult(0).getType();
156 if (auto lType = dyn_cast<LValueType>(resultType))
157 resultType = lType.getValueType();
158 Type attrType = cast<TypedAttr>(value).getType();
159
160 if (isPointerWideType(resultType) && attrType.isIndex())
161 return success();
162
163 if (resultType != attrType)
164 return op->emitOpError()
165 << "requires attribute to either be an #emitc.opaque attribute or "
166 "it's type ("
167 << attrType << ") to match the op's result type (" << resultType
168 << ")";
169
170 return success();
171}
172
173/// Parse a format string and return a list of its parts.
174/// A part is either a StringRef that has to be printed as-is, or
175/// a Placeholder which requires printing the next operand of the VerbatimOp.
176/// In the format string, all `{}` are replaced by Placeholders, except if the
177/// `{` is escaped by `{{` - then it doesn't start a placeholder.
/// If `fmtArgs` is empty, the whole string is returned unparsed as one part.
/// A `{` followed by anything other than `{` or `}` (or by end of string)
/// makes the parse fail.
/// NOTE(review): the trailing parameter and a local declaration (source lines
/// 181-182) are missing from this listing. From their uses, `emitError`
/// appears to be a nullable callable returning a diagnostic, and `items` is
/// the result vector.
178template <class ArgType>
179FailureOr<SmallVector<ReplacementItem>> parseFormatString(
180 StringRef toParse, ArgType fmtArgs,
183
184 // If there are not operands, the format string is not interpreted.
185 if (fmtArgs.empty()) {
186 items.push_back(toParse);
187 return items;
188 }
189
190 while (!toParse.empty()) {
191 size_t idx = toParse.find('{');
192 if (idx == StringRef::npos) {
193 // No '{'
194 items.push_back(toParse);
195 break;
196 }
197 if (idx > 0) {
198 // Take all chars excluding the '{'.
199 items.push_back(toParse.take_front(idx));
200 toParse = toParse.drop_front(idx);
201 continue;
202 }
 // NOTE(review): unlike the final error path below, this branch calls
 // emitError() without first checking that it is set. If emitError can be
 // empty, a trailing '{' would call through a null callable. Consider
 // `if (emitError) return emitError() << ...; return failure();`.
203 if (toParse.size() < 2) {
204 return emitError() << "expected '}' after unescaped '{' at end of string";
205 }
206 // toParse contains at least two characters and starts with `{`.
207 char nextChar = toParse[1];
208 if (nextChar == '{') {
209 // Double '{{' -> '{' (escaping).
210 items.push_back(toParse.take_front(1));
211 toParse = toParse.drop_front(2);
212 continue;
213 }
214 if (nextChar == '}') {
215 items.push_back(Placeholder{});
216 toParse = toParse.drop_front(2);
217 continue;
218 }
219
220 if (emitError) {
221 return emitError() << "expected '}' after unescaped '{'";
222 }
223 return failure();
224 }
225 return items;
226}
227
228//===----------------------------------------------------------------------===//
229// AddOp
230//===----------------------------------------------------------------------===//
231
232LogicalResult AddOp::verify() {
233 Type lhsType = getLhs().getType();
234 Type rhsType = getRhs().getType();
235
236 if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
237 return emitOpError("requires that at most one operand is a pointer");
238
239 if ((isa<emitc::PointerType>(lhsType) &&
240 !isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
241 (isa<emitc::PointerType>(rhsType) &&
242 !isa<IntegerType, emitc::OpaqueType>(lhsType)))
243 return emitOpError("requires that one operand is an integer or of opaque "
244 "type if the other is a pointer");
245
246 return success();
247}
248
249//===----------------------------------------------------------------------===//
250// ApplyOp
251//===----------------------------------------------------------------------===//
252
253LogicalResult ApplyOp::verify() {
254 StringRef applicableOperatorStr = getApplicableOperator();
255
256 // Applicable operator must not be empty.
257 if (applicableOperatorStr.empty())
258 return emitOpError("applicable operator must not be empty");
259
260 // Only `*` and `&` are supported.
261 if (applicableOperatorStr != "&" && applicableOperatorStr != "*")
262 return emitOpError("applicable operator is illegal");
263
264 Type operandType = getOperand().getType();
265 Type resultType = getResult().getType();
266 if (applicableOperatorStr == "&") {
267 if (!llvm::isa<emitc::LValueType>(operandType))
268 return emitOpError("operand type must be an lvalue when applying `&`");
269 if (!llvm::isa<emitc::PointerType>(resultType))
270 return emitOpError("result type must be a pointer when applying `&`");
271 } else {
272 if (!llvm::isa<emitc::PointerType>(operandType))
273 return emitOpError("operand type must be a pointer when applying `*`");
274 }
275
276 return success();
277}
278
279//===----------------------------------------------------------------------===//
280// AssignOp
281//===----------------------------------------------------------------------===//
282
283/// The assign op requires that the assigned value's type matches the
284/// assigned-to variable type.
/// The target must be produced by an operation; block arguments are rejected.
/// The comparison uses the value type of the target's type (`getValueType()`).
/// NOTE(review): the declaration of `variable` (source line 286) is missing
/// from this listing.
285LogicalResult emitc::AssignOp::verify() {
287
288 if (!variable.getDefiningOp())
289 return emitOpError() << "cannot assign to block argument";
290
291 Type valueType = getValue().getType();
292 Type variableType = variable.getType().getValueType();
293 if (variableType != valueType)
294 return emitOpError() << "requires value's type (" << valueType
295 << ") to match variable's type (" << variableType
296 << ")\n variable: " << variable
297 << "\n value: " << getValue() << "\n";
298 return success();
299}
300
301//===----------------------------------------------------------------------===//
302// CastOp
303//===----------------------------------------------------------------------===//
304
/// Returns whether a cast from `inputs.front()` to `outputs.front()` is legal.
/// Only the first input and output types are inspected.
///  - An emitc array may only be cast to an emitc pointer. The array must be
///    one-dimensional with at least one element, and its element type must
///    equal the pointee type.
///  - Otherwise both types must satisfy the scalar predicate below. Part of it
///    (source lines 317 and 319) is missing from this listing; the visible
///    terms accept supported float types and emitc pointers.
305bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
306 Type input = inputs.front(), output = outputs.front();
307
308 if (auto arrayType = dyn_cast<emitc::ArrayType>(input)) {
309 if (auto pointerType = dyn_cast<emitc::PointerType>(output)) {
310 return (arrayType.getElementType() == pointerType.getPointee()) &&
311 arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1;
312 }
313 return false;
314 }
315
316 return (
318 emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
320 emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
321}
322
323//===----------------------------------------------------------------------===//
324// CallOpaqueOp
325//===----------------------------------------------------------------------===//
326
327LogicalResult emitc::CallOpaqueOp::verify() {
328 // Callee must not be empty.
329 if (getCallee().empty())
330 return emitOpError("callee must not be empty");
331
332 if (std::optional<ArrayAttr> argsAttr = getArgs()) {
333 for (Attribute arg : *argsAttr) {
334 auto intAttr = llvm::dyn_cast<IntegerAttr>(arg);
335 if (intAttr && llvm::isa<IndexType>(intAttr.getType())) {
336 int64_t index = intAttr.getInt();
337 // Args with elements of type index must be in range
338 // [0..operands.size).
339 if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands())))
340 return emitOpError("index argument is out of range");
341
342 // Args with elements of type ArrayAttr must have a type.
343 } else if (llvm::isa<ArrayAttr>(
344 arg) /*&& llvm::isa<NoneType>(arg.getType())*/) {
345 // FIXME: Array attributes never have types
346 return emitOpError("array argument has no type");
347 }
348 }
349 }
350
351 if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
352 for (Attribute tArg : *templateArgsAttr) {
353 if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg))
354 return emitOpError("template argument has invalid type");
355 }
356 }
357
358 if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
359 return emitOpError() << "cannot return array type";
360 }
361
362 return success();
363}
364
365//===----------------------------------------------------------------------===//
366// ConstantOp
367//===----------------------------------------------------------------------===//
368
369LogicalResult emitc::ConstantOp::verify() {
370 Attribute value = getValueAttr();
371 if (failed(verifyInitializationAttribute(getOperation(), value)))
372 return failure();
373 if (auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(value)) {
374 if (opaqueValue.getValue().empty())
375 return emitOpError() << "value must not be empty";
376 }
377 return success();
378}
379
380OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
381
382//===----------------------------------------------------------------------===//
383// ExpressionOp
384//===----------------------------------------------------------------------===//
385
/// Parses the custom form
///   <operand-list> [noinline] : <function-type> <region>
/// The function type's inputs resolve the operands and must describe exactly
/// one result. The single-block region receives one argument per operand,
/// named after that operand (name shadowing is enabled).
/// NOTE(review): the declarations of `operands` (source line 387) and
/// `argsInfo` (source line 409) are missing from this listing.
386ParseResult ExpressionOp::parse(OpAsmParser &parser, OperationState &result) {
388 if (parser.parseOperandList(operands))
389 return parser.emitError(parser.getCurrentLocation()) << "expected operands";
390 if (succeeded(parser.parseOptionalKeyword("noinline")))
391 result.addAttribute(ExpressionOp::getDoNotInlineAttrName(result.name),
392 parser.getBuilder().getUnitAttr());
393 Type type;
394 if (parser.parseColonType(type))
395 return parser.emitError(parser.getCurrentLocation(),
396 "expected function type");
397 auto fnType = llvm::dyn_cast<FunctionType>(type);
398 if (!fnType)
399 return parser.emitError(parser.getCurrentLocation(),
400 "expected function type");
401 if (parser.resolveOperands(operands, fnType.getInputs(),
402 parser.getCurrentLocation(), result.operands))
403 return failure();
404 if (fnType.getNumResults() != 1)
405 return parser.emitError(parser.getCurrentLocation(),
406 "expected single return type");
407 result.addTypes(fnType.getResults());
408 Region *body = result.addRegion();
// Region arguments mirror the operands one-to-one, reusing their SSA names.
410 for (auto [unresolvedOperand, operandType] :
411 llvm::zip(operands, fnType.getInputs())) {
412 OpAsmParser::Argument argInfo;
413 argInfo.ssaName = unresolvedOperand;
414 argInfo.type = operandType;
415 argsInfo.push_back(argInfo);
416 }
417 if (parser.parseRegion(*body, argsInfo, /*enableNameShadowing=*/true))
418 return failure();
419 return success();
420}
421
422void emitc::ExpressionOp::print(OpAsmPrinter &p) {
423 p << ' ';
424 p.printOperands(getDefs());
425 p << " : ";
426 p.printFunctionalType(getOperation());
427 p.shadowRegionArgs(getRegion(), getDefs());
428 p << ' ';
429 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
430}
431
432Operation *ExpressionOp::getRootOp() {
433 auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
434 Value yieldedValue = yieldOp.getResult();
435 return yieldedValue.getDefiningOp();
436}
437
/// Verifies emitc.expression:
///  - the body yields a value defined by an operation inside the expression;
///  - the yielded type equals the result type;
///  - every non-terminator op implements CExpressionInterface, has exactly one
///    result, and that result is used;
///  - no side-effecting op is reachable more than once from the root op.
/// NOTE(review): the declarations of `visited` and `worklist` (source lines
/// 479-480) are missing from this listing.
438LogicalResult ExpressionOp::verify() {
439 Type resultType = getResult().getType();
440 Region &region = getRegion();
441
442 Block &body = region.front();
443
444 if (!body.mightHaveTerminator())
445 return emitOpError("must yield a value at termination");
446
447 auto yield = cast<YieldOp>(body.getTerminator());
448 Value yieldResult = yield.getResult();
449
450 if (!yieldResult)
451 return emitOpError("must yield a value at termination");
452
453 Operation *rootOp = yieldResult.getDefiningOp();
454
455 if (!rootOp)
456 return emitOpError("yielded value has no defining op");
457
458 if (rootOp->getParentOp() != getOperation())
459 return emitOpError("yielded value not defined within expression");
460
461 Type yieldType = yieldResult.getType();
462
463 if (resultType != yieldType)
464 return emitOpError("requires yielded type to match return type");
465
466 for (Operation &op : region.front().without_terminator()) {
467 auto expressionInterface = dyn_cast<emitc::CExpressionInterface>(op);
468 if (!expressionInterface)
469 return emitOpError("contains an unsupported operation");
470 if (op.getNumResults() != 1)
471 return emitOpError("requires exactly one result for each operation");
472 Value result = op.getResult(0);
473 if (result.use_empty())
474 return emitOpError("contains an unused operation");
475 }
476
477 // Make sure any operation with side effect is only reachable once from
478 // the root op, otherwise emission will be replicating side effects.
 // Already-visited ops are deliberately expanded again. Each path from the
 // root counts separately, so a side-effecting op beneath a shared operand
 // is still detected. The walk's cost can therefore grow with the number of
 // paths through the expression DAG.
 // NOTE(review): `cast<CExpressionInterface>` assumes every reachable defining
 // op lives in the body (checked by the loop above); confirm that operands
 // cannot be defined outside the expression.
481 worklist.push_back(rootOp);
482 while (!worklist.empty()) {
483 Operation *op = worklist.back();
484 worklist.pop_back();
485 if (visited.contains(op)) {
486 if (cast<CExpressionInterface>(op).hasSideEffects())
487 return emitOpError(
488 "requires exactly one use for operations with side effects");
489 }
490 visited.insert(op);
491 for (Value operand : op->getOperands())
492 if (Operation *def = operand.getDefiningOp()) {
493 worklist.push_back(def);
494 }
495 }
496
497 return success();
498}
499
500//===----------------------------------------------------------------------===//
501// ForOp
502//===----------------------------------------------------------------------===//
503
504void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
505 Value ub, Value step, BodyBuilderFn bodyBuilder) {
506 OpBuilder::InsertionGuard g(builder);
507 result.addOperands({lb, ub, step});
508 Type t = lb.getType();
509 Region *bodyRegion = result.addRegion();
510 Block *bodyBlock = builder.createBlock(bodyRegion);
511 bodyBlock->addArgument(t, result.location);
512
513 // Create the default terminator if the builder is not provided.
514 if (!bodyBuilder) {
515 ForOp::ensureTerminator(*bodyRegion, builder, result.location);
516 } else {
517 OpBuilder::InsertionGuard guard(builder);
518 builder.setInsertionPointToStart(bodyBlock);
519 bodyBuilder(builder, result.location, bodyBlock->getArgument(0));
520 }
521}
522
523void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}
524
/// Parses the custom form
///   %iv = %lb to %ub step %step [: <type>] <region> [attr-dict]
/// When the type is omitted it defaults to index. The same type is used for
/// the bounds, the step and the induction variable. A terminator is added to
/// the body via ensureTerminator.
/// NOTE(review): the declarations of `lb`, `ub`, `step` (source line 530) and
/// `regionArgs` (source line 541) are missing from this listing.
525ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
526 Builder &builder = parser.getBuilder();
527 Type type;
528
529 OpAsmParser::Argument inductionVariable;
531
532 // Parse the induction variable followed by '='.
533 if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
534 // Parse loop bounds.
535 parser.parseOperand(lb) || parser.parseKeyword("to") ||
536 parser.parseOperand(ub) || parser.parseKeyword("step") ||
537 parser.parseOperand(step))
538 return failure();
539
540 // The induction variable is the only region argument; no iteration
 // arguments are parsed.
542 regionArgs.push_back(inductionVariable);
543
544 // Parse optional type, else assume Index.
545 if (parser.parseOptionalColon())
546 type = builder.getIndexType();
547 else if (parser.parseType(type))
548 return failure();
549
550 // Resolve input operands.
551 regionArgs.front().type = type;
552 if (parser.resolveOperand(lb, type, result.operands) ||
553 parser.resolveOperand(ub, type, result.operands) ||
554 parser.resolveOperand(step, type, result.operands))
555 return failure();
556
557 // Parse the body region.
558 Region *body = result.addRegion();
559 if (parser.parseRegion(*body, regionArgs))
560 return failure();
561
562 ForOp::ensureTerminator(*body, builder, result.location);
563
564 // Parse the optional attribute list.
565 if (parser.parseOptionalAttrDict(result.attributes))
566 return failure();
567
568 return success();
569}
570
571void ForOp::print(OpAsmPrinter &p) {
572 p << " " << getInductionVar() << " = " << getLowerBound() << " to "
573 << getUpperBound() << " step " << getStep();
574
575 p << ' ';
576 if (Type t = getInductionVar().getType(); !t.isIndex())
577 p << " : " << t << ' ';
578 p.printRegion(getRegion(),
579 /*printEntryBlockArgs=*/false,
580 /*printBlockTerminators=*/false);
581 p.printOptionalAttrDict((*this)->getAttrs());
582}
583
584LogicalResult ForOp::verifyRegions() {
585 // Check that the body defines as single block argument for the induction
586 // variable.
587 if (getBody()->getNumArguments() != 1)
588 return emitOpError("expected body to have a single block argument for the "
589 "induction variable");
590
591 if (getInductionVar().getType() != getLowerBound().getType())
592 return emitOpError(
593 "expected induction variable to be same type as bounds and step");
594
595 return success();
596}
597
598//===----------------------------------------------------------------------===//
599// CallOp
600//===----------------------------------------------------------------------===//
601
602LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
603 // Check that the callee attribute was specified.
604 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
605 if (!fnAttr)
606 return emitOpError("requires a 'callee' symbol reference attribute");
607 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
608 if (!fn)
609 return emitOpError() << "'" << fnAttr.getValue()
610 << "' does not reference a valid function";
611
612 // Verify that the operand and result types match the callee.
613 auto fnType = fn.getFunctionType();
614 if (fnType.getNumInputs() != getNumOperands())
615 return emitOpError("incorrect number of operands for callee");
616
617 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
618 if (getOperand(i).getType() != fnType.getInput(i))
619 return emitOpError("operand type mismatch: expected operand type ")
620 << fnType.getInput(i) << ", but provided "
621 << getOperand(i).getType() << " for operand number " << i;
622
623 if (fnType.getNumResults() != getNumResults())
624 return emitOpError("incorrect number of results for callee");
625
626 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
627 if (getResult(i).getType() != fnType.getResult(i)) {
628 auto diag = emitOpError("result type mismatch at index ") << i;
629 diag.attachNote() << " op result types: " << getResultTypes();
630 diag.attachNote() << "function result types: " << fnType.getResults();
631 return diag;
632 }
633
634 return success();
635}
636
637FunctionType CallOp::getCalleeType() {
638 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
639}
640
641//===----------------------------------------------------------------------===//
642// DeclareFuncOp
643//===----------------------------------------------------------------------===//
644
645LogicalResult
646DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
647 // Check that the sym_name attribute was specified.
648 auto fnAttr = getSymNameAttr();
649 if (!fnAttr)
650 return emitOpError("requires a 'sym_name' symbol reference attribute");
651 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
652 if (!fn)
653 return emitOpError() << "'" << fnAttr.getValue()
654 << "' does not reference a valid function";
655
656 return success();
657}
658
659//===----------------------------------------------------------------------===//
660// FuncOp
661//===----------------------------------------------------------------------===//
662
/// Builds a function op named `name` with type `type`. The extra `attrs` are
/// appended and one (empty) region is added. If `argAttrs` is non-empty, it
/// must hold one dictionary per function input (asserted); these are attached
/// as argument attributes, with no result attributes.
/// NOTE(review): source lines 666 and 675 (the start of the symbol-name
/// attribute call and the start of the arg-attr helper call) are missing from
/// this listing.
663void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
664 FunctionType type, ArrayRef<NamedAttribute> attrs,
665 ArrayRef<DictionaryAttr> argAttrs) {
667 builder.getStringAttr(name));
668 state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
669 state.attributes.append(attrs.begin(), attrs.end());
670 state.addRegion();
671
672 if (argAttrs.empty())
673 return;
674 assert(type.getNumInputs() == argAttrs.size());
676 builder, state, argAttrs, /*resultAttrs=*/{},
677 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
678}
679
/// Parses a function op through a shared function-op parsing helper, with
/// variadic arguments disallowed. `buildFuncType` turns the parsed argument
/// and result types into a FunctionType.
/// NOTE(review): source lines 683 (part of the lambda's parameter list) and
/// 686 (the helper call itself) are missing from this listing.
680ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
681 auto buildFuncType =
682 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
684 std::string &) { return builder.getFunctionType(argTypes, results); };
685
687 parser, result, /*allowVariadic=*/false,
688 getFunctionTypeAttrName(result.name), buildFuncType,
689 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
690}
691
/// Prints this function through a shared function-op printing helper, as a
/// non-variadic function, using this op's function-type, arg-attrs and
/// res-attrs attribute names.
/// NOTE(review): the helper call line (source line 693) is missing from this
/// listing.
692void FuncOp::print(OpAsmPrinter &p) {
694 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
695 getArgAttrsAttrName(), getResAttrsAttrName());
696}
697
698LogicalResult FuncOp::verify() {
699 if (llvm::any_of(getArgumentTypes(), llvm::IsaPred<LValueType>)) {
700 return emitOpError("cannot have lvalue type as argument");
701 }
702
703 if (getNumResults() > 1)
704 return emitOpError("requires zero or exactly one result, but has ")
705 << getNumResults();
706
707 if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
708 return emitOpError("cannot return array type");
709
710 return success();
711}
712
713//===----------------------------------------------------------------------===//
714// ReturnOp
715//===----------------------------------------------------------------------===//
716
717LogicalResult ReturnOp::verify() {
718 auto function = cast<FuncOp>((*this)->getParentOp());
719
720 // The operand number and types must match the function signature.
721 if (getNumOperands() != function.getNumResults())
722 return emitOpError("has ")
723 << getNumOperands() << " operands, but enclosing function (@"
724 << function.getName() << ") returns " << function.getNumResults();
725
726 if (function.getNumResults() == 1)
727 if (getOperand().getType() != function.getResultTypes()[0])
728 return emitError() << "type of the return operand ("
729 << getOperand().getType()
730 << ") doesn't match function result type ("
731 << function.getResultTypes()[0] << ")"
732 << " in function @" << function.getName();
733 return success();
734}
735
736//===----------------------------------------------------------------------===//
737// IfOp
738//===----------------------------------------------------------------------===//
739
740void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
741 bool addThenBlock, bool addElseBlock) {
742 assert((!addElseBlock || addThenBlock) &&
743 "must not create else block w/o then block");
744 result.addOperands(cond);
745
746 // Add regions and blocks.
747 OpBuilder::InsertionGuard guard(builder);
748 Region *thenRegion = result.addRegion();
749 if (addThenBlock)
750 builder.createBlock(thenRegion);
751 Region *elseRegion = result.addRegion();
752 if (addElseBlock)
753 builder.createBlock(elseRegion);
754}
755
756void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
757 bool withElseRegion) {
758 result.addOperands(cond);
759
760 // Build then region.
761 OpBuilder::InsertionGuard guard(builder);
762 Region *thenRegion = result.addRegion();
763 builder.createBlock(thenRegion);
764
765 // Build else region.
766 Region *elseRegion = result.addRegion();
767 if (withElseRegion) {
768 builder.createBlock(elseRegion);
769 }
770}
771
772void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
773 function_ref<void(OpBuilder &, Location)> thenBuilder,
774 function_ref<void(OpBuilder &, Location)> elseBuilder) {
775 assert(thenBuilder && "the builder callback for 'then' must be present");
776 result.addOperands(cond);
777
778 // Build then region.
779 OpBuilder::InsertionGuard guard(builder);
780 Region *thenRegion = result.addRegion();
781 builder.createBlock(thenRegion);
782 thenBuilder(builder, result.location);
783
784 // Build else region.
785 Region *elseRegion = result.addRegion();
786 if (elseBuilder) {
787 builder.createBlock(elseRegion);
788 elseBuilder(builder, result.location);
789 }
790}
791
/// Parses the custom form
///   %cond <then-region> [else <else-region>] [attr-dict]
/// The condition is resolved as i1. A terminator is inserted into each parsed
/// region via ensureTerminator. The 'else' region stays empty unless the
/// `else` keyword is present.
/// NOTE(review): the declaration of `cond` (source line 799) is missing from
/// this listing.
792ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
793 // Create both regions up front; 'else' stays empty unless parsed below.
794 result.regions.reserve(2);
795 Region *thenRegion = result.addRegion();
796 Region *elseRegion = result.addRegion();
797
798 Builder &builder = parser.getBuilder();
800 Type i1Type = builder.getIntegerType(1);
801 if (parser.parseOperand(cond) ||
802 parser.resolveOperand(cond, i1Type, result.operands))
803 return failure();
804 // Parse the 'then' region.
805 if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
806 return failure();
807 IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
808
809 // If we find an 'else' keyword then parse the 'else' region.
810 if (!parser.parseOptionalKeyword("else")) {
811 if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
812 return failure();
813 IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
814 }
815
816 // Parse the optional attribute list.
817 if (parser.parseOptionalAttrDict(result.attributes))
818 return failure();
819 return success();
820}
821
822void IfOp::print(OpAsmPrinter &p) {
823 bool printBlockTerminators = false;
824
825 p << " " << getCondition();
826 p << ' ';
827 p.printRegion(getThenRegion(),
828 /*printEntryBlockArgs=*/false,
829 /*printBlockTerminators=*/printBlockTerminators);
830
831 // Print the 'else' regions if it exists and has a block.
832 Region &elseRegion = getElseRegion();
833 if (!elseRegion.empty()) {
834 p << " else ";
835 p.printRegion(elseRegion,
836 /*printEntryBlockArgs=*/false,
837 /*printBlockTerminators=*/printBlockTerminators);
838 }
839
840 p.printOptionalAttrDict((*this)->getAttrs());
841}
842
843/// Given a branch `point` (either the parent operation or one of its regions),
844/// return the regions control may flow to next. From the 'then' or 'else'
845/// region, control always returns to the parent op and its results. From the
846/// parent, control may enter the 'then' region, and either the 'else' region
847/// or, when that region is empty, the parent op's results directly.
/// NOTE(review): the second parameter line (source line 849) is missing from
/// this listing; from its uses, `regions` is the output list.
848void IfOp::getSuccessorRegions(RegionBranchPoint point,
850 // The `then` and the `else` region branch back to the parent operation.
851 if (!point.isParent()) {
852 regions.push_back(
853 RegionSuccessor(getOperation(), getOperation()->getResults()));
854 return;
855 }
856
857 regions.push_back(RegionSuccessor(&getThenRegion()));
858
859 // Don't consider the else region if it is empty.
860 Region *elseRegion = &this->getElseRegion();
861 if (elseRegion->empty())
862 regions.push_back(
863 RegionSuccessor(getOperation(), getOperation()->getResults()));
864 else
865 regions.push_back(RegionSuccessor(elseRegion));
866}
867
/// Entry successors, refined by a possibly constant condition in `operands`:
///  - the 'then' region is listed unless the condition is known false;
///  - the 'else' path is listed unless the condition is known true. It goes to
///    the 'else' region, or to the parent op's results when that region is
///    empty.
/// NOTE(review): the second parameter line (source line 869) is missing from
/// this listing; from its uses, `regions` is the output list.
868void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
870 FoldAdaptor adaptor(operands, *this);
871 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
872 if (!boolAttr || boolAttr.getValue())
873 regions.emplace_back(&getThenRegion());
874
875 // If the else region is empty, execution continues after the parent op.
876 if (!boolAttr || !boolAttr.getValue()) {
877 if (!getElseRegion().empty())
878 regions.emplace_back(&getElseRegion());
879 else
880 regions.emplace_back(getOperation(), getOperation()->getResults());
881 }
882}
883
884void IfOp::getRegionInvocationBounds(
885 ArrayRef<Attribute> operands,
886 SmallVectorImpl<InvocationBounds> &invocationBounds) {
887 if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
888 // If the condition is known, then one region is known to be executed once
889 // and the other zero times.
890 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
891 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
892 } else {
893 // Non-constant condition. Each region may be executed 0 or 1 times.
894 invocationBounds.assign(2, {0, 1});
895 }
896}
897
898//===----------------------------------------------------------------------===//
899// IncludeOp
900//===----------------------------------------------------------------------===//
901
902void IncludeOp::print(OpAsmPrinter &p) {
903 bool standardInclude = getIsStandardInclude();
904
905 p << " ";
906 if (standardInclude)
907 p << "<";
908 p << "\"" << getInclude() << "\"";
909 if (standardInclude)
910 p << ">";
911}
912
913ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
914 bool standardInclude = !parser.parseOptionalLess();
915
916 StringAttr include;
917 OptionalParseResult includeParseResult =
918 parser.parseOptionalAttribute(include, "include", result.attributes);
919 if (!includeParseResult.has_value())
920 return parser.emitError(parser.getNameLoc()) << "expected string attribute";
921
922 if (standardInclude && parser.parseOptionalGreater())
923 return parser.emitError(parser.getNameLoc())
924 << "expected trailing '>' for standard include";
925
926 if (standardInclude)
927 result.addAttribute("is_standard_include",
928 UnitAttr::get(parser.getContext()));
929
930 return success();
931}
932
933//===----------------------------------------------------------------------===//
934// LiteralOp
935//===----------------------------------------------------------------------===//
936
937/// The literal op requires a non-empty value.
938LogicalResult emitc::LiteralOp::verify() {
939 if (getValue().empty())
940 return emitOpError() << "value must not be empty";
941 return success();
942}
943//===----------------------------------------------------------------------===//
944// SubOp
945//===----------------------------------------------------------------------===//
946
947LogicalResult SubOp::verify() {
948 Type lhsType = getLhs().getType();
949 Type rhsType = getRhs().getType();
950 Type resultType = getResult().getType();
951
952 if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
953 return emitOpError("rhs can only be a pointer if lhs is a pointer");
954
955 if (isa<emitc::PointerType>(lhsType) &&
956 !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
957 return emitOpError("requires that rhs is an integer, pointer or of opaque "
958 "type if lhs is a pointer");
959
960 if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
961 !isa<IntegerType, emitc::PtrDiffTType, emitc::OpaqueType>(resultType))
962 return emitOpError("requires that the result is an integer, ptrdiff_t or "
963 "of opaque type if lhs and rhs are pointers");
964 return success();
965}
966
967//===----------------------------------------------------------------------===//
968// VariableOp
969//===----------------------------------------------------------------------===//
970
971LogicalResult emitc::VariableOp::verify() {
972 return verifyInitializationAttribute(getOperation(), getValueAttr());
973}
974
975//===----------------------------------------------------------------------===//
976// YieldOp
977//===----------------------------------------------------------------------===//
978
979LogicalResult emitc::YieldOp::verify() {
980 Value result = getResult();
981 Operation *containingOp = getOperation()->getParentOp();
982
983 if (!isa<DoOp>(containingOp) && result && containingOp->getNumResults() != 1)
984 return emitOpError() << "yields a value not returned by parent";
985
986 if (!isa<DoOp>(containingOp) && !result && containingOp->getNumResults() != 0)
987 return emitOpError() << "does not yield a value to be returned by parent";
988
989 return success();
990}
991
992//===----------------------------------------------------------------------===//
993// SubscriptOp
994//===----------------------------------------------------------------------===//
995
996LogicalResult emitc::SubscriptOp::verify() {
997 // Checks for array operand.
998 if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
999 // Check number of indices.
1000 if (getIndices().size() != (size_t)arrayType.getRank()) {
1001 return emitOpError() << "on array operand requires number of indices ("
1002 << getIndices().size()
1003 << ") to match the rank of the array type ("
1004 << arrayType.getRank() << ")";
1005 }
1006 // Check types of index operands.
1007 for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
1008 Type type = getIndices()[i].getType();
1009 if (!isIntegerIndexOrOpaqueType(type)) {
1010 return emitOpError() << "on array operand requires index operand " << i
1011 << " to be integer-like, but got " << type;
1012 }
1013 }
1014 // Check element type.
1015 Type elementType = arrayType.getElementType();
1016 Type resultType = getType().getValueType();
1017 if (elementType != resultType) {
1018 return emitOpError() << "on array operand requires element type ("
1019 << elementType << ") and result type (" << resultType
1020 << ") to match";
1021 }
1022 return success();
1023 }
1024
1025 // Checks for pointer operand.
1026 if (auto pointerType =
1027 llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
1028 // Check number of indices.
1029 if (getIndices().size() != 1) {
1030 return emitOpError()
1031 << "on pointer operand requires one index operand, but got "
1032 << getIndices().size();
1033 }
1034 // Check types of index operand.
1035 Type type = getIndices()[0].getType();
1036 if (!isIntegerIndexOrOpaqueType(type)) {
1037 return emitOpError() << "on pointer operand requires index operand to be "
1038 "integer-like, but got "
1039 << type;
1040 }
1041 // Check pointee type.
1042 Type pointeeType = pointerType.getPointee();
1043 Type resultType = getType().getValueType();
1044 if (pointeeType != resultType) {
1045 return emitOpError() << "on pointer operand requires pointee type ("
1046 << pointeeType << ") and result type (" << resultType
1047 << ") to match";
1048 }
1049 return success();
1050 }
1051
1052 // The operand has opaque type, so we can't assume anything about the number
1053 // or types of index operands.
1054 return success();
1055}
1056
1057//===----------------------------------------------------------------------===//
1058// VerbatimOp
1059//===----------------------------------------------------------------------===//
1060
1061LogicalResult emitc::VerbatimOp::verify() {
1062 auto errorCallback = [&]() -> InFlightDiagnostic {
1063 return this->emitOpError();
1064 };
1065 FailureOr<SmallVector<ReplacementItem>> fmt =
1066 ::parseFormatString(getValue(), getFmtArgs(), errorCallback);
1067 if (failed(fmt))
1068 return failure();
1069
1070 size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) {
1071 return std::holds_alternative<Placeholder>(item);
1072 });
1073
1074 if (numPlaceholders != getFmtArgs().size()) {
1075 return emitOpError()
1076 << "requires operands for each placeholder in the format string";
1077 }
1078 return success();
1079}
1080
1081FailureOr<SmallVector<ReplacementItem>> emitc::VerbatimOp::parseFormatString() {
1082 // Error checking is done in verify.
1083 return ::parseFormatString(getValue(), getFmtArgs());
1084}
1085
1086//===----------------------------------------------------------------------===//
1087// EmitC Enums
1088//===----------------------------------------------------------------------===//
1089
1090#include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
1091
1092//===----------------------------------------------------------------------===//
1093// EmitC Attributes
1094//===----------------------------------------------------------------------===//
1095
1096#define GET_ATTRDEF_CLASSES
1097#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
1098
1099//===----------------------------------------------------------------------===//
1100// EmitC Types
1101//===----------------------------------------------------------------------===//
1102
1103#define GET_TYPEDEF_CLASSES
1104#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
1105
1106//===----------------------------------------------------------------------===//
1107// ArrayType
1108//===----------------------------------------------------------------------===//
1109
1110Type emitc::ArrayType::parse(AsmParser &parser) {
1111 if (parser.parseLess())
1112 return Type();
1113
1114 SmallVector<int64_t, 4> dimensions;
1115 if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
1116 /*withTrailingX=*/true))
1117 return Type();
1118 // Parse the element type.
1119 auto typeLoc = parser.getCurrentLocation();
1120 Type elementType;
1121 if (parser.parseType(elementType))
1122 return Type();
1123
1124 // Check that array is formed from allowed types.
1125 if (!isValidElementType(elementType))
1126 return parser.emitError(typeLoc, "invalid array element type '")
1127 << elementType << "'",
1128 Type();
1129 if (parser.parseGreater())
1130 return Type();
1131 return parser.getChecked<ArrayType>(dimensions, elementType);
1132}
1133
1134void emitc::ArrayType::print(AsmPrinter &printer) const {
1135 printer << "<";
1136 for (int64_t dim : getShape()) {
1137 printer << dim << 'x';
1138 }
1139 printer.printType(getElementType());
1140 printer << ">";
1141}
1142
/// Verifies the parameters of !emitc.array: the shape must be non-empty and
/// contain no negative dimension, and the element type must be non-null and
/// accepted by isValidElementType.
/// NOTE(review): the parameter line declaring the `emitError` callback
/// (original line 1144) is missing from this extracted copy; the body's
/// `emitError()` calls imply it exists — confirm against upstream.
1143LogicalResult emitc::ArrayType::verify(
1145 ::llvm::ArrayRef<int64_t> shape, Type elementType) {
1146 if (shape.empty())
1147 return emitError() << "shape must not be empty";
1148
// Zero-sized dimensions are accepted; only negative sizes are rejected.
1149 for (int64_t dim : shape) {
1150 if (dim < 0)
1151 return emitError() << "dimensions must have non-negative size";
1152 }
1153
1154 if (!elementType)
1155 return emitError() << "element type must not be none";
1156
1157 if (!isValidElementType(elementType))
1158 return emitError() << "invalid array element type";
1159
1160 return success();
1161}
1162
1163emitc::ArrayType
1164emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
1165 Type elementType) const {
1166 if (!shape)
1167 return emitc::ArrayType::get(getShape(), elementType);
1168 return emitc::ArrayType::get(*shape, elementType);
1169}
1170
1171//===----------------------------------------------------------------------===//
1172// LValueType
1173//===----------------------------------------------------------------------===//
1174
/// Verifies the type wrapped by !emitc.lvalue: it must be a supported EmitC
/// type (per isSupportedEmitCType) and must not be an !emitc.array.
/// NOTE(review): the parameter line declaring the `emitError` callback
/// (original line 1176) is missing from this extracted copy; the body's
/// `emitError()` calls imply it exists — confirm against upstream.
1175LogicalResult mlir::emitc::LValueType::verify(
1177 mlir::Type value) {
1178 // Check that the wrapped type is valid. This especially forbids nested
1179 // lvalue types.
1180 if (!isSupportedEmitCType(value))
1181 return emitError()
1182 << "!emitc.lvalue must wrap supported emitc type, but got " << value;
1183
// Array types are rejected explicitly as lvalue payloads.
1184 if (llvm::isa<emitc::ArrayType>(value))
1185 return emitError() << "!emitc.lvalue cannot wrap !emitc.array type";
1186
1187 return success();
1188}
1189
1190//===----------------------------------------------------------------------===//
1191// OpaqueType
1192//===----------------------------------------------------------------------===//
1193
/// Verifies the string of an !emitc.opaque type: it must be non-empty and must
/// not end in '*'. Pointer types should be spelled with !emitc.ptr instead.
/// NOTE(review): the parameter line declaring the `emitError` callback
/// (original line 1195) is missing from this extracted copy — confirm against
/// upstream.
1194LogicalResult mlir::emitc::OpaqueType::verify(
1196 llvm::StringRef value) {
1197 if (value.empty()) {
1198 return emitError() << "expected non empty string in !emitc.opaque type";
1199 }
// Only a trailing '*' is rejected; a '*' elsewhere in the string is allowed.
1200 if (value.back() == '*') {
1201 return emitError() << "pointer not allowed as outer type with "
1202 "!emitc.opaque, use !emitc.ptr instead";
1203 }
1204 return success();
1205}
1206
1207//===----------------------------------------------------------------------===//
1208// PointerType
1209//===----------------------------------------------------------------------===//
1210
/// Verifies an !emitc.ptr type: the type held in `value` must not be an
/// !emitc.lvalue.
/// NOTE(review): original line 1212 is missing from this extracted copy. It
/// should declare the `emitError` callback and the `value` parameter and open
/// the body; confirm the exact parameter types against upstream.
1211LogicalResult mlir::emitc::PointerType::verify(
1213 if (llvm::isa<emitc::LValueType>(value))
1214 return emitError() << "pointers to lvalues are not allowed";
1215
1216 return success();
1217}
1218
1219//===----------------------------------------------------------------------===//
1220// GlobalOp
1221//===----------------------------------------------------------------------===//
1223 TypeAttr type,
1224 Attribute initialValue) {
1225 p << type;
1226 if (initialValue) {
1227 p << " = ";
1228 p.printAttributeWithoutType(initialValue);
1229 }
1230}
1231
1233 if (auto array = llvm::dyn_cast<ArrayType>(type))
1234 return RankedTensorType::get(array.getShape(), array.getElementType());
1235 return type;
1236}
1237
1238static ParseResult
1240 Attribute &initialValue) {
1241 Type type;
1242 if (parser.parseType(type))
1243 return failure();
1244
1245 typeAttr = TypeAttr::get(type);
1246
1247 if (parser.parseOptionalEqual())
1248 return success();
1249
1250 if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
1251 return failure();
1252
1253 if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1254 initialValue))
1255 return parser.emitError(parser.getNameLoc())
1256 << "initial value should be a integer, float, elements or opaque "
1257 "attribute";
1258 return success();
1259}
1260
/// Verifies an emitc.global:
///  - its type must be a supported EmitC type;
///  - an initial value, when present, must be either an elements attribute
///    (the global must then have array type and the attribute's type must
///    match), an integer or float attribute whose type equals the global's
///    type, or an emitc opaque attribute;
///  - the static and extern specifiers are mutually exclusive.
1261LogicalResult GlobalOp::verify() {
1262 if (!isSupportedEmitCType(getType())) {
1263 return emitOpError("expected valid emitc type");
1264 }
1265 if (getInitialValue().has_value()) {
1266 Attribute initValue = getInitialValue().value();
1267 // Check that the type of the initial value is compatible with the type of
1268 // the global variable.
1269 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1270 auto arrayType = llvm::dyn_cast<ArrayType>(getType());
1271 if (!arrayType)
1272 return emitOpError("expected array type, but got ") << getType();
1273
1274 Type initType = elementsAttr.getType();
// NOTE(review): `tensorType` is used below, but its declaration (original
// line 1275) is missing from this extracted copy. By analogy with the parser,
// it is presumably `getInitializerTypeForGlobal(getType())`, i.e. the ranked
// tensor form of the array type — confirm against upstream.
1276 if (initType != tensorType) {
1277 return emitOpError("initial value expected to be of type ")
1278 << getType() << ", but was of type " << initType;
1279 }
1280 } else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
1281 if (intAttr.getType() != getType()) {
1282 return emitOpError("initial value expected to be of type ")
1283 << getType() << ", but was of type " << intAttr.getType();
1284 }
1285 } else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
1286 if (floatAttr.getType() != getType()) {
1287 return emitOpError("initial value expected to be of type ")
1288 << getType() << ", but was of type " << floatAttr.getType();
1289 }
1290 } else if (!isa<emitc::OpaqueAttr>(initValue)) {
1291 return emitOpError("initial value should be a integer, float, elements "
1292 "or opaque attribute, but got ")
1293 << initValue;
1294 }
1295 }
1296 if (getStaticSpecifier() && getExternSpecifier()) {
1297 return emitOpError("cannot have both static and extern specifiers");
1298 }
1299 return success();
1300}
1301
1302//===----------------------------------------------------------------------===//
1303// GetGlobalOp
1304//===----------------------------------------------------------------------===//
1305
1306LogicalResult
1307GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1308 // Verify that the type matches the type of the global variable.
1309 auto global =
1310 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1311 if (!global)
1312 return emitOpError("'")
1313 << getName() << "' does not reference a valid emitc.global";
1314
1315 Type resultType = getResult().getType();
1316 Type globalType = global.getType();
1317
1318 // global has array type
1319 if (llvm::isa<ArrayType>(globalType)) {
1320 if (globalType != resultType)
1321 return emitOpError("on array type expects result type ")
1322 << resultType << " to match type " << globalType
1323 << " of the global @" << getName();
1324 return success();
1325 }
1326
1327 // global has non-array type
1328 auto lvalueType = dyn_cast<LValueType>(resultType);
1329 if (!lvalueType)
1330 return emitOpError("on non-array type expects result type to be an "
1331 "lvalue type for the global @")
1332 << getName();
1333 if (lvalueType.getValueType() != globalType)
1334 return emitOpError("on non-array type expects result inner type ")
1335 << lvalueType.getValueType() << " to match type " << globalType
1336 << " of the global @" << getName();
1337 return success();
1338}
1339
1340//===----------------------------------------------------------------------===//
1341// SwitchOp
1342//===----------------------------------------------------------------------===//
1343
1344/// Parse the case regions and values.
1345static ParseResult
1347 SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1348 SmallVector<int64_t> caseValues;
1349 while (succeeded(parser.parseOptionalKeyword("case"))) {
1350 int64_t value;
1351 Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
1352 if (parser.parseInteger(value) ||
1353 parser.parseRegion(region, /*arguments=*/{}))
1354 return failure();
1355 caseValues.push_back(value);
1356 }
1357 cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
1358 return success();
1359}
1360
1361/// Print the case regions and values.
1363 DenseI64ArrayAttr cases, RegionRange caseRegions) {
1364 for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
1365 p.printNewline();
1366 p << "case " << value << ' ';
1367 p.printRegion(*region, /*printEntryBlockArgs=*/false);
1368 }
1369}
1370
1371static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
1372 const Twine &name) {
1373 auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
1374 if (!yield)
1375 return op.emitOpError("expected region to end with emitc.yield, but got ")
1376 << region.front().back().getName();
1377
1378 if (yield.getNumOperands() != 0) {
1379 return (op.emitOpError("expected each region to return ")
1380 << "0 values, but " << name << " returns "
1381 << yield.getNumOperands())
1382 .attachNote(yield.getLoc())
1383 << "see yield operation here";
1384 }
1385
1386 return success();
1387}
1388
1389LogicalResult emitc::SwitchOp::verify() {
1390 if (!isIntegerIndexOrOpaqueType(getArg().getType()))
1391 return emitOpError("unsupported type ") << getArg().getType();
1392
1393 if (getCases().size() != getCaseRegions().size()) {
1394 return emitOpError("has ")
1395 << getCaseRegions().size() << " case regions but "
1396 << getCases().size() << " case values";
1397 }
1398
1399 DenseSet<int64_t> valueSet;
1400 for (int64_t value : getCases())
1401 if (!valueSet.insert(value).second)
1402 return emitOpError("has duplicate case value: ") << value;
1403
1404 if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
1405 return failure();
1406
1407 for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
1408 if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
1409 return failure();
1410
1411 return success();
1412}
1413
1414unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }
1415
1416Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
1417
1418Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
1419 assert(idx < getNumCases() && "case index out-of-bounds");
1420 return getCaseRegions()[idx].front();
1421}
1422
/// Reports every region of the switch as a possible successor.
/// NOTE(review): the parameter line (original line 1424, which declares
/// `successors` and presumably the region branch point) is missing from this
/// extracted copy — confirm the exact signature against upstream.
1423void SwitchOp::getSuccessorRegions(
1425 llvm::append_range(successors, getRegions());
1426}
1427
/// Computes the regions that can be entered from the switch op itself. If the
/// argument folds to an IntegerAttr, the only successor is the case region
/// whose value matches, or the default region when no case matches. Otherwise
/// every region is a possible successor.
/// NOTE(review): the parameter line declaring `successors` (original line
/// 1430) is missing from this extracted copy — confirm against upstream.
1428void SwitchOp::getEntrySuccessorRegions(
1429 ArrayRef<Attribute> operands,
1431 FoldAdaptor adaptor(operands, *this);
1432
1433 // If a constant was not provided, all regions are possible successors.
1434 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
1435 if (!arg) {
1436 llvm::append_range(successors, getRegions());
1437 return;
1438 }
1439
1440 // Otherwise, try to find a case with a matching value. If not, the
1441 // default region is the only successor.
1442 for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
1443 if (caseValue == arg.getInt()) {
1444 successors.emplace_back(&caseRegion);
1445 return;
1446 }
1447 }
1448 successors.emplace_back(&getDefaultRegion());
1449}
1450
/// Computes per-region invocation bounds. Without a constant IntegerAttr as
/// the first operand, every region gets bounds [0, 1]. With one, the region at
/// `liveIndex` gets [0, 1] and all other regions [0, 0]. `liveIndex` is the
/// position of the matching value in getCases(), or getNumRegions() - 1 when no
/// case matches.
/// NOTE(review): this compares a position in getCases() against region
/// indices. That is only correct if case region i is region i and the default
/// region is the last region. If the op declares the default region first,
/// these indices are off by one; verify against the op's region order.
/// NOTE(review): the parameter line (original line 1452, declaring `operands`
/// and `bounds`) is missing from this extracted copy.
1451void SwitchOp::getRegionInvocationBounds(
1453 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
1454 if (!operandValue) {
1455 // All regions are invoked at most once.
1456 bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
1457 return;
1458 }
1459
// Default to the last region; overridden below when a case value matches.
1460 unsigned liveIndex = getNumRegions() - 1;
1461 const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
1462
1463 liveIndex = iteratorToInt != getCases().end()
1464 ? std::distance(getCases().begin(), iteratorToInt)
1465 : liveIndex;
1466
// The upper bound is the bool `regIndex == liveIndex`, i.e. 1 only for the
// live region.
1467 for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
1468 ++regIndex)
1469 bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
1470}
1471
1472//===----------------------------------------------------------------------===//
1473// FileOp
1474//===----------------------------------------------------------------------===//
1475void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
1476 state.addRegion()->emplaceBlock();
1477 state.attributes.push_back(
1478 builder.getNamedAttr("id", builder.getStringAttr(id)));
1479}
1480
1481//===----------------------------------------------------------------------===//
1482// FieldOp
1483//===----------------------------------------------------------------------===//
1484
1486 TypeAttr type,
1487 Attribute initialValue) {
1488 p << type;
1489 if (initialValue) {
1490 p << " = ";
1491 p.printAttributeWithoutType(initialValue);
1492 }
1493}
1494
1496 if (auto array = llvm::dyn_cast<ArrayType>(type))
1497 return RankedTensorType::get(array.getShape(), array.getElementType());
1498 return type;
1499}
1500
1501static ParseResult
1503 Attribute &initialValue) {
1504 Type type;
1505 if (parser.parseType(type))
1506 return failure();
1507
1508 typeAttr = TypeAttr::get(type);
1509
1510 if (parser.parseOptionalEqual())
1511 return success();
1512
1513 if (parser.parseAttribute(initialValue, getInitializerTypeForField(type)))
1514 return failure();
1515
1516 if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1517 initialValue))
1518 return parser.emitError(parser.getNameLoc())
1519 << "initial value should be a integer, float, elements or opaque "
1520 "attribute";
1521 return success();
1522}
1523
/// Verifies an emitc.field: its immediate parent must be an emitc.class, and
/// it must carry a non-empty symbol name.
/// NOTE(review): original line 1525, the condition guarding the "expected
/// valid emitc type" error, is missing from this extracted copy, so the return
/// below looks unconditional here. By analogy with GlobalOp::verify it is
/// presumably `if (!isSupportedEmitCType(getType()))` — confirm against
/// upstream.
1524LogicalResult FieldOp::verify() {
1526 return emitOpError("expected valid emitc type");
1527
// Only the direct parent is checked, not further ancestors.
1528 Operation *parentOp = getOperation()->getParentOp();
1529 if (!parentOp || !isa<emitc::ClassOp>(parentOp))
1530 return emitOpError("field must be nested within an emitc.class operation");
1531
1532 StringAttr symName = getSymNameAttr();
1533 if (!symName || symName.getValue().empty())
1534 return emitOpError("field must have a non-empty symbol name");
1535
1536 return success();
1537}
1538
1539//===----------------------------------------------------------------------===//
1540// GetFieldOp
1541//===----------------------------------------------------------------------===//
1542
1543LogicalResult GetFieldOp::verify() {
1544 auto parentClassOp = getOperation()->getParentOfType<emitc::ClassOp>();
1545 if (!parentClassOp.getOperation())
1546 return emitOpError(" must be nested within an emitc.class operation");
1547
1548 return success();
1549}
1550
1551LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1552 mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
1553 FieldOp fieldOp =
1554 symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, fieldNameAttr);
1555 if (!fieldOp)
1556 return emitOpError("field '")
1557 << fieldNameAttr << "' not found in the class";
1558
1559 Type getFieldResultType = getResult().getType();
1560 Type fieldType = fieldOp.getType();
1561
1562 if (fieldType != getFieldResultType)
1563 return emitOpError("result type ")
1564 << getFieldResultType << " does not match field '" << fieldNameAttr
1565 << "' type " << fieldType;
1566
1567 return success();
1568}
1569
1570//===----------------------------------------------------------------------===//
1571// DoOp
1572//===----------------------------------------------------------------------===//
1573
1574void DoOp::print(OpAsmPrinter &p) {
1575 p << ' ';
1576 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
1577 p << " while ";
1578 p.printRegion(getConditionRegion());
1579 p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
1580}
1581
1582LogicalResult emitc::DoOp::verify() {
1583 Block &condBlock = getConditionRegion().front();
1584
1585 if (condBlock.getOperations().size() != 2)
1586 return emitOpError(
1587 "condition region must contain exactly two operations: "
1588 "'emitc.expression' followed by 'emitc.yield', but found ")
1589 << condBlock.getOperations().size() << " operations";
1590
1591 Operation &first = condBlock.front();
1592 auto exprOp = dyn_cast<emitc::ExpressionOp>(first);
1593 if (!exprOp)
1594 return emitOpError("expected first op in condition region to be "
1595 "'emitc.expression', but got ")
1596 << first.getName();
1597
1598 if (!exprOp.getResult().getType().isInteger(1))
1599 return emitOpError("emitc.expression in condition region must return "
1600 "'i1', but returns ")
1601 << exprOp.getResult().getType();
1602
1603 Operation &last = condBlock.back();
1604 auto condYield = dyn_cast<emitc::YieldOp>(last);
1605 if (!condYield)
1606 return emitOpError("expected last op in condition region to be "
1607 "'emitc.yield', but got ")
1608 << last.getName();
1609
1610 if (condYield.getNumOperands() != 1)
1611 return emitOpError("expected condition region to return 1 value, but "
1612 "it returns ")
1613 << condYield.getNumOperands() << " values";
1614
1615 if (condYield.getOperand(0) != exprOp.getResult())
1616 return emitError("'emitc.yield' must return result of "
1617 "'emitc.expression' from this condition region");
1618
1619 Block &bodyBlock = getBodyRegion().front();
1620 if (bodyBlock.mightHaveTerminator())
1621 return emitOpError("body region must not contain terminator");
1622
1623 return success();
1624}
1625
1626ParseResult DoOp::parse(OpAsmParser &parser, OperationState &result) {
1627 Region *bodyRegion = result.addRegion();
1628 Region *condRegion = result.addRegion();
1629
1630 if (parser.parseRegion(*bodyRegion) || parser.parseKeyword("while") ||
1631 parser.parseRegion(*condRegion))
1632 return failure();
1633
1634 if (bodyRegion->empty())
1635 bodyRegion->emplaceBlock();
1636
1637 return parser.parseOptionalAttrDictWithKeyword(result.attributes);
1638}
1639
1640//===----------------------------------------------------------------------===//
1641// TableGen'd op method definitions
1642//===----------------------------------------------------------------------===//
1643
1644#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.cpp.inc"
1645
1646#define GET_OP_CLASSES
1647#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static bool hasSideEffects(Operation *op)
static LogicalResult verifyInitializationAttribute(Operation *op, Attribute value)
Check that the type of the initial value is compatible with the operations result type.
Definition EmitC.cpp:144
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1371
static ParseResult parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition EmitC.cpp:1239
static Type getInitializerTypeForField(Type type)
Definition EmitC.cpp:1495
static ParseResult parseEmitCFieldOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition EmitC.cpp:1502
FailureOr< SmallVector< ReplacementItem > > parseFormatString(StringRef toParse, ArgType fmtArgs, llvm::function_ref< mlir::InFlightDiagnostic()> emitError={})
Parse a format string and return a list of its parts.
Definition EmitC.cpp:179
static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
Definition EmitC.cpp:1222
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
Definition EmitC.cpp:1346
static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op, TypeAttr type, Attribute initialValue)
Definition EmitC.cpp:1485
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition EmitC.cpp:1362
static Type getInitializerTypeForGlobal(Type type)
Definition EmitC.cpp:1232
static Type getElementType(Type type)
Determine the element type of type.
b getContext())
static std::string diag(const llvm::Value &value)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
OpListType & getOperations()
Definition Block.h:137
Operation & front()
Definition Block.h:153
Operation & back()
Definition Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:153
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:250
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition Block.h:212
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
UnitAttr getUnitAttr()
Definition Builders.cpp:98
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:167
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
IndexType getIndexType()
Definition Builders.cpp:51
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:94
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void shadowRegionArgs(Region &region, ValueRange namesToUse)=0
Renumber the arguments for the specified region to the same names as the SSA values in namesToUse.
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
This class represents a single result from folding an operation.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition Operation.h:407
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
Block & emplaceBlock()
Definition Region.h:46
bool empty()
Definition Region.h:60
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIndex() const
Definition Types.cpp:54
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A named class for passing around the variadic flag.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:4625
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for builders of ops carrying a region.
Definition EmitC.cpp:57
std::variant< StringRef, Placeholder > ReplacementItem
Definition EmitC.h:54
bool isFundamentalType(mlir::Type type)
Determines whether type is a valid fundamental C++ type in EmitC.
Definition EmitC.cpp:136
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition EmitC.cpp:116
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Definition EmitC.cpp:61
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t/ssize_t type.
Definition EmitC.cpp:131
bool isIntegerIndexOrOpaqueType(Type type)
Determines whether type is integer-like, i.e. an integer, index, or opaque type.
Definition EmitC.cpp:111
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition EmitC.cpp:95
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.