MLIR  21.0.0git
EmitC.cpp
Go to the documentation of this file.
1 //===- EmitC.cpp - EmitC Dialect ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Types.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Casting.h"
22 #include "llvm/Support/FormatVariadic.h"
23 
24 using namespace mlir;
25 using namespace mlir::emitc;
26 
27 #include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc"
28 
29 //===----------------------------------------------------------------------===//
30 // EmitCDialect
31 //===----------------------------------------------------------------------===//
32 
33 void EmitCDialect::initialize() {
34  addOperations<
35 #define GET_OP_LIST
36 #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
37  >();
38  addTypes<
39 #define GET_TYPEDEF_LIST
40 #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
41  >();
42  addAttributes<
43 #define GET_ATTRDEF_LIST
44 #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
45  >();
46 }
47 
48 /// Materialize a single constant operation from a given attribute value with
49 /// the desired resultant type.
51  Attribute value, Type type,
52  Location loc) {
53  return builder.create<emitc::ConstantOp>(loc, type, value);
54 }
55 
56 /// Default callback for builders of ops carrying a region. Inserts a yield
57 /// without arguments.
59  builder.create<emitc::YieldOp>(loc);
60 }
61 
63  if (llvm::isa<emitc::OpaqueType>(type))
64  return true;
65  if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
66  return isSupportedEmitCType(ptrType.getPointee());
67  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
68  auto elemType = arrayType.getElementType();
69  return !llvm::isa<emitc::ArrayType>(elemType) &&
70  isSupportedEmitCType(elemType);
71  }
72  if (type.isIndex() || emitc::isPointerWideType(type))
73  return true;
74  if (llvm::isa<IntegerType>(type))
75  return isSupportedIntegerType(type);
76  if (llvm::isa<FloatType>(type))
77  return isSupportedFloatType(type);
78  if (auto tensorType = llvm::dyn_cast<TensorType>(type)) {
79  if (!tensorType.hasStaticShape()) {
80  return false;
81  }
82  auto elemType = tensorType.getElementType();
83  if (llvm::isa<emitc::ArrayType>(elemType)) {
84  return false;
85  }
86  return isSupportedEmitCType(elemType);
87  }
88  if (auto tupleType = llvm::dyn_cast<TupleType>(type)) {
89  return llvm::all_of(tupleType.getTypes(), [](Type type) {
90  return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
91  });
92  }
93  return false;
94 }
95 
97  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
98  switch (intType.getWidth()) {
99  case 1:
100  case 8:
101  case 16:
102  case 32:
103  case 64:
104  return true;
105  default:
106  return false;
107  }
108  }
109  return false;
110 }
111 
113  return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
115 }
116 
118  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
119  switch (floatType.getWidth()) {
120  case 16: {
121  if (llvm::isa<Float16Type, BFloat16Type>(type))
122  return true;
123  return false;
124  }
125  case 32:
126  case 64:
127  return true;
128  default:
129  return false;
130  }
131  }
132  return false;
133 }
134 
136  return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
137  type);
138 }
139 
140 /// Check that the type of the initial value is compatible with the operations
141 /// result type.
142 static LogicalResult verifyInitializationAttribute(Operation *op,
143  Attribute value) {
144  assert(op->getNumResults() == 1 && "operation must have 1 result");
145 
146  if (llvm::isa<emitc::OpaqueAttr>(value))
147  return success();
148 
149  if (llvm::isa<StringAttr>(value))
150  return op->emitOpError()
151  << "string attributes are not supported, use #emitc.opaque instead";
152 
153  Type resultType = op->getResult(0).getType();
154  if (auto lType = dyn_cast<LValueType>(resultType))
155  resultType = lType.getValueType();
156  Type attrType = cast<TypedAttr>(value).getType();
157 
158  if (isPointerWideType(resultType) && attrType.isIndex())
159  return success();
160 
161  if (resultType != attrType)
162  return op->emitOpError()
163  << "requires attribute to either be an #emitc.opaque attribute or "
164  "it's type ("
165  << attrType << ") to match the op's result type (" << resultType
166  << ")";
167 
168  return success();
169 }
170 
171 /// Parse a format string and return a list of its parts.
172 /// A part is either a StringRef that has to be printed as-is, or
173 /// a Placeholder which requires printing the next operand of the VerbatimOp.
174 /// In the format string, all `{}` are replaced by Placeholders, except if the
175 /// `{` is escaped by `{{` - then it doesn't start a placeholder.
176 template <class ArgType>
177 FailureOr<SmallVector<ReplacementItem>>
178 parseFormatString(StringRef toParse, ArgType fmtArgs,
180  emitError = {}) {
182 
183  // If there are not operands, the format string is not interpreted.
184  if (fmtArgs.empty()) {
185  items.push_back(toParse);
186  return items;
187  }
188 
189  while (!toParse.empty()) {
190  size_t idx = toParse.find('{');
191  if (idx == StringRef::npos) {
192  // No '{'
193  items.push_back(toParse);
194  break;
195  }
196  if (idx > 0) {
197  // Take all chars excluding the '{'.
198  items.push_back(toParse.take_front(idx));
199  toParse = toParse.drop_front(idx);
200  continue;
201  }
202  if (toParse.size() < 2) {
203  return (*emitError)()
204  << "expected '}' after unescaped '{' at end of string";
205  }
206  // toParse contains at least two characters and starts with `{`.
207  char nextChar = toParse[1];
208  if (nextChar == '{') {
209  // Double '{{' -> '{' (escaping).
210  items.push_back(toParse.take_front(1));
211  toParse = toParse.drop_front(2);
212  continue;
213  }
214  if (nextChar == '}') {
215  items.push_back(Placeholder{});
216  toParse = toParse.drop_front(2);
217  continue;
218  }
219 
220  if (emitError.has_value()) {
221  return (*emitError)() << "expected '}' after unescaped '{'";
222  }
223  return failure();
224  }
225  return items;
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // AddOp
230 //===----------------------------------------------------------------------===//
231 
232 LogicalResult AddOp::verify() {
233  Type lhsType = getLhs().getType();
234  Type rhsType = getRhs().getType();
235 
236  if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
237  return emitOpError("requires that at most one operand is a pointer");
238 
239  if ((isa<emitc::PointerType>(lhsType) &&
240  !isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
241  (isa<emitc::PointerType>(rhsType) &&
242  !isa<IntegerType, emitc::OpaqueType>(lhsType)))
243  return emitOpError("requires that one operand is an integer or of opaque "
244  "type if the other is a pointer");
245 
246  return success();
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // ApplyOp
251 //===----------------------------------------------------------------------===//
252 
253 LogicalResult ApplyOp::verify() {
254  StringRef applicableOperatorStr = getApplicableOperator();
255 
256  // Applicable operator must not be empty.
257  if (applicableOperatorStr.empty())
258  return emitOpError("applicable operator must not be empty");
259 
260  // Only `*` and `&` are supported.
261  if (applicableOperatorStr != "&" && applicableOperatorStr != "*")
262  return emitOpError("applicable operator is illegal");
263 
264  Type operandType = getOperand().getType();
265  Type resultType = getResult().getType();
266  if (applicableOperatorStr == "&") {
267  if (!llvm::isa<emitc::LValueType>(operandType))
268  return emitOpError("operand type must be an lvalue when applying `&`");
269  if (!llvm::isa<emitc::PointerType>(resultType))
270  return emitOpError("result type must be a pointer when applying `&`");
271  } else {
272  if (!llvm::isa<emitc::PointerType>(operandType))
273  return emitOpError("operand type must be a pointer when applying `*`");
274  }
275 
276  return success();
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // AssignOp
281 //===----------------------------------------------------------------------===//
282 
283 /// The assign op requires that the assigned value's type matches the
284 /// assigned-to variable type.
285 LogicalResult emitc::AssignOp::verify() {
287 
288  if (!variable.getDefiningOp())
289  return emitOpError() << "cannot assign to block argument";
290 
291  Type valueType = getValue().getType();
292  Type variableType = variable.getType().getValueType();
293  if (variableType != valueType)
294  return emitOpError() << "requires value's type (" << valueType
295  << ") to match variable's type (" << variableType
296  << ")\n variable: " << variable
297  << "\n value: " << getValue() << "\n";
298  return success();
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // CastOp
303 //===----------------------------------------------------------------------===//
304 
305 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
306  Type input = inputs.front(), output = outputs.front();
307 
308  if (auto arrayType = dyn_cast<emitc::ArrayType>(input)) {
309  if (auto pointerType = dyn_cast<emitc::PointerType>(output)) {
310  return (arrayType.getElementType() == pointerType.getPointee()) &&
311  arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1;
312  }
313  return false;
314  }
315 
316  return (
318  emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
320  emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
321 }
322 
323 //===----------------------------------------------------------------------===//
324 // CallOpaqueOp
325 //===----------------------------------------------------------------------===//
326 
327 LogicalResult emitc::CallOpaqueOp::verify() {
328  // Callee must not be empty.
329  if (getCallee().empty())
330  return emitOpError("callee must not be empty");
331 
332  if (std::optional<ArrayAttr> argsAttr = getArgs()) {
333  for (Attribute arg : *argsAttr) {
334  auto intAttr = llvm::dyn_cast<IntegerAttr>(arg);
335  if (intAttr && llvm::isa<IndexType>(intAttr.getType())) {
336  int64_t index = intAttr.getInt();
337  // Args with elements of type index must be in range
338  // [0..operands.size).
339  if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands())))
340  return emitOpError("index argument is out of range");
341 
342  // Args with elements of type ArrayAttr must have a type.
343  } else if (llvm::isa<ArrayAttr>(
344  arg) /*&& llvm::isa<NoneType>(arg.getType())*/) {
345  // FIXME: Array attributes never have types
346  return emitOpError("array argument has no type");
347  }
348  }
349  }
350 
351  if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
352  for (Attribute tArg : *templateArgsAttr) {
353  if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg))
354  return emitOpError("template argument has invalid type");
355  }
356  }
357 
358  if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
359  return emitOpError() << "cannot return array type";
360  }
361 
362  return success();
363 }
364 
365 //===----------------------------------------------------------------------===//
366 // ConstantOp
367 //===----------------------------------------------------------------------===//
368 
369 LogicalResult emitc::ConstantOp::verify() {
370  Attribute value = getValueAttr();
371  if (failed(verifyInitializationAttribute(getOperation(), value)))
372  return failure();
373  if (auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(value)) {
374  if (opaqueValue.getValue().empty())
375  return emitOpError() << "value must not be empty";
376  }
377  return success();
378 }
379 
380 OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
381 
382 //===----------------------------------------------------------------------===//
383 // ExpressionOp
384 //===----------------------------------------------------------------------===//
385 
386 Operation *ExpressionOp::getRootOp() {
387  auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
388  Value yieldedValue = yieldOp.getResult();
389  Operation *rootOp = yieldedValue.getDefiningOp();
390  assert(rootOp && "Yielded value not defined within expression");
391  return rootOp;
392 }
393 
394 LogicalResult ExpressionOp::verify() {
395  Type resultType = getResult().getType();
396  Region &region = getRegion();
397 
398  Block &body = region.front();
399 
400  if (!body.mightHaveTerminator())
401  return emitOpError("must yield a value at termination");
402 
403  auto yield = cast<YieldOp>(body.getTerminator());
404  Value yieldResult = yield.getResult();
405 
406  if (!yieldResult)
407  return emitOpError("must yield a value at termination");
408 
409  Type yieldType = yieldResult.getType();
410 
411  if (resultType != yieldType)
412  return emitOpError("requires yielded type to match return type");
413 
414  for (Operation &op : region.front().without_terminator()) {
415  if (!op.hasTrait<OpTrait::emitc::CExpression>())
416  return emitOpError("contains an unsupported operation");
417  if (op.getNumResults() != 1)
418  return emitOpError("requires exactly one result for each operation");
419  if (!op.getResult(0).hasOneUse())
420  return emitOpError("requires exactly one use for each operation");
421  }
422 
423  return success();
424 }
425 
426 //===----------------------------------------------------------------------===//
427 // ForOp
428 //===----------------------------------------------------------------------===//
429 
430 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
431  Value ub, Value step, BodyBuilderFn bodyBuilder) {
432  OpBuilder::InsertionGuard g(builder);
433  result.addOperands({lb, ub, step});
434  Type t = lb.getType();
435  Region *bodyRegion = result.addRegion();
436  Block *bodyBlock = builder.createBlock(bodyRegion);
437  bodyBlock->addArgument(t, result.location);
438 
439  // Create the default terminator if the builder is not provided.
440  if (!bodyBuilder) {
441  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
442  } else {
443  OpBuilder::InsertionGuard guard(builder);
444  builder.setInsertionPointToStart(bodyBlock);
445  bodyBuilder(builder, result.location, bodyBlock->getArgument(0));
446  }
447 }
448 
449 void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}
450 
451 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
452  Builder &builder = parser.getBuilder();
453  Type type;
454 
455  OpAsmParser::Argument inductionVariable;
456  OpAsmParser::UnresolvedOperand lb, ub, step;
457 
458  // Parse the induction variable followed by '='.
459  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
460  // Parse loop bounds.
461  parser.parseOperand(lb) || parser.parseKeyword("to") ||
462  parser.parseOperand(ub) || parser.parseKeyword("step") ||
463  parser.parseOperand(step))
464  return failure();
465 
466  // Parse the optional initial iteration arguments.
468  regionArgs.push_back(inductionVariable);
469 
470  // Parse optional type, else assume Index.
471  if (parser.parseOptionalColon())
472  type = builder.getIndexType();
473  else if (parser.parseType(type))
474  return failure();
475 
476  // Resolve input operands.
477  regionArgs.front().type = type;
478  if (parser.resolveOperand(lb, type, result.operands) ||
479  parser.resolveOperand(ub, type, result.operands) ||
480  parser.resolveOperand(step, type, result.operands))
481  return failure();
482 
483  // Parse the body region.
484  Region *body = result.addRegion();
485  if (parser.parseRegion(*body, regionArgs))
486  return failure();
487 
488  ForOp::ensureTerminator(*body, builder, result.location);
489 
490  // Parse the optional attribute list.
491  if (parser.parseOptionalAttrDict(result.attributes))
492  return failure();
493 
494  return success();
495 }
496 
497 void ForOp::print(OpAsmPrinter &p) {
498  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
499  << getUpperBound() << " step " << getStep();
500 
501  p << ' ';
502  if (Type t = getInductionVar().getType(); !t.isIndex())
503  p << " : " << t << ' ';
504  p.printRegion(getRegion(),
505  /*printEntryBlockArgs=*/false,
506  /*printBlockTerminators=*/false);
507  p.printOptionalAttrDict((*this)->getAttrs());
508 }
509 
510 LogicalResult ForOp::verifyRegions() {
511  // Check that the body defines as single block argument for the induction
512  // variable.
513  if (getInductionVar().getType() != getLowerBound().getType())
514  return emitOpError(
515  "expected induction variable to be same type as bounds and step");
516 
517  return success();
518 }
519 
520 //===----------------------------------------------------------------------===//
521 // CallOp
522 //===----------------------------------------------------------------------===//
523 
524 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
525  // Check that the callee attribute was specified.
526  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
527  if (!fnAttr)
528  return emitOpError("requires a 'callee' symbol reference attribute");
529  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
530  if (!fn)
531  return emitOpError() << "'" << fnAttr.getValue()
532  << "' does not reference a valid function";
533 
534  // Verify that the operand and result types match the callee.
535  auto fnType = fn.getFunctionType();
536  if (fnType.getNumInputs() != getNumOperands())
537  return emitOpError("incorrect number of operands for callee");
538 
539  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
540  if (getOperand(i).getType() != fnType.getInput(i))
541  return emitOpError("operand type mismatch: expected operand type ")
542  << fnType.getInput(i) << ", but provided "
543  << getOperand(i).getType() << " for operand number " << i;
544 
545  if (fnType.getNumResults() != getNumResults())
546  return emitOpError("incorrect number of results for callee");
547 
548  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
549  if (getResult(i).getType() != fnType.getResult(i)) {
550  auto diag = emitOpError("result type mismatch at index ") << i;
551  diag.attachNote() << " op result types: " << getResultTypes();
552  diag.attachNote() << "function result types: " << fnType.getResults();
553  return diag;
554  }
555 
556  return success();
557 }
558 
559 FunctionType CallOp::getCalleeType() {
560  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
561 }
562 
563 //===----------------------------------------------------------------------===//
564 // DeclareFuncOp
565 //===----------------------------------------------------------------------===//
566 
567 LogicalResult
568 DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
569  // Check that the sym_name attribute was specified.
570  auto fnAttr = getSymNameAttr();
571  if (!fnAttr)
572  return emitOpError("requires a 'sym_name' symbol reference attribute");
573  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
574  if (!fn)
575  return emitOpError() << "'" << fnAttr.getValue()
576  << "' does not reference a valid function";
577 
578  return success();
579 }
580 
581 //===----------------------------------------------------------------------===//
582 // FuncOp
583 //===----------------------------------------------------------------------===//
584 
585 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
586  FunctionType type, ArrayRef<NamedAttribute> attrs,
587  ArrayRef<DictionaryAttr> argAttrs) {
588  state.addAttribute(SymbolTable::getSymbolAttrName(),
589  builder.getStringAttr(name));
590  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
591  state.attributes.append(attrs.begin(), attrs.end());
592  state.addRegion();
593 
594  if (argAttrs.empty())
595  return;
596  assert(type.getNumInputs() == argAttrs.size());
598  builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
599  getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
600 }
601 
602 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
603  auto buildFuncType =
604  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
606  std::string &) { return builder.getFunctionType(argTypes, results); };
607 
609  parser, result, /*allowVariadic=*/false,
610  getFunctionTypeAttrName(result.name), buildFuncType,
611  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
612 }
613 
614 void FuncOp::print(OpAsmPrinter &p) {
616  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
617  getArgAttrsAttrName(), getResAttrsAttrName());
618 }
619 
620 LogicalResult FuncOp::verify() {
621  if (llvm::any_of(getArgumentTypes(), llvm::IsaPred<LValueType>)) {
622  return emitOpError("cannot have lvalue type as argument");
623  }
624 
625  if (getNumResults() > 1)
626  return emitOpError("requires zero or exactly one result, but has ")
627  << getNumResults();
628 
629  if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
630  return emitOpError("cannot return array type");
631 
632  return success();
633 }
634 
635 //===----------------------------------------------------------------------===//
636 // ReturnOp
637 //===----------------------------------------------------------------------===//
638 
639 LogicalResult ReturnOp::verify() {
640  auto function = cast<FuncOp>((*this)->getParentOp());
641 
642  // The operand number and types must match the function signature.
643  if (getNumOperands() != function.getNumResults())
644  return emitOpError("has ")
645  << getNumOperands() << " operands, but enclosing function (@"
646  << function.getName() << ") returns " << function.getNumResults();
647 
648  if (function.getNumResults() == 1)
649  if (getOperand().getType() != function.getResultTypes()[0])
650  return emitError() << "type of the return operand ("
651  << getOperand().getType()
652  << ") doesn't match function result type ("
653  << function.getResultTypes()[0] << ")"
654  << " in function @" << function.getName();
655  return success();
656 }
657 
658 //===----------------------------------------------------------------------===//
659 // IfOp
660 //===----------------------------------------------------------------------===//
661 
662 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
663  bool addThenBlock, bool addElseBlock) {
664  assert((!addElseBlock || addThenBlock) &&
665  "must not create else block w/o then block");
666  result.addOperands(cond);
667 
668  // Add regions and blocks.
669  OpBuilder::InsertionGuard guard(builder);
670  Region *thenRegion = result.addRegion();
671  if (addThenBlock)
672  builder.createBlock(thenRegion);
673  Region *elseRegion = result.addRegion();
674  if (addElseBlock)
675  builder.createBlock(elseRegion);
676 }
677 
678 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
679  bool withElseRegion) {
680  result.addOperands(cond);
681 
682  // Build then region.
683  OpBuilder::InsertionGuard guard(builder);
684  Region *thenRegion = result.addRegion();
685  builder.createBlock(thenRegion);
686 
687  // Build else region.
688  Region *elseRegion = result.addRegion();
689  if (withElseRegion) {
690  builder.createBlock(elseRegion);
691  }
692 }
693 
694 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
695  function_ref<void(OpBuilder &, Location)> thenBuilder,
696  function_ref<void(OpBuilder &, Location)> elseBuilder) {
697  assert(thenBuilder && "the builder callback for 'then' must be present");
698  result.addOperands(cond);
699 
700  // Build then region.
701  OpBuilder::InsertionGuard guard(builder);
702  Region *thenRegion = result.addRegion();
703  builder.createBlock(thenRegion);
704  thenBuilder(builder, result.location);
705 
706  // Build else region.
707  Region *elseRegion = result.addRegion();
708  if (elseBuilder) {
709  builder.createBlock(elseRegion);
710  elseBuilder(builder, result.location);
711  }
712 }
713 
714 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
715  // Create the regions for 'then'.
716  result.regions.reserve(2);
717  Region *thenRegion = result.addRegion();
718  Region *elseRegion = result.addRegion();
719 
720  Builder &builder = parser.getBuilder();
722  Type i1Type = builder.getIntegerType(1);
723  if (parser.parseOperand(cond) ||
724  parser.resolveOperand(cond, i1Type, result.operands))
725  return failure();
726  // Parse the 'then' region.
727  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
728  return failure();
729  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
730 
731  // If we find an 'else' keyword then parse the 'else' region.
732  if (!parser.parseOptionalKeyword("else")) {
733  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
734  return failure();
735  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
736  }
737 
738  // Parse the optional attribute list.
739  if (parser.parseOptionalAttrDict(result.attributes))
740  return failure();
741  return success();
742 }
743 
744 void IfOp::print(OpAsmPrinter &p) {
745  bool printBlockTerminators = false;
746 
747  p << " " << getCondition();
748  p << ' ';
749  p.printRegion(getThenRegion(),
750  /*printEntryBlockArgs=*/false,
751  /*printBlockTerminators=*/printBlockTerminators);
752 
753  // Print the 'else' regions if it exists and has a block.
754  Region &elseRegion = getElseRegion();
755  if (!elseRegion.empty()) {
756  p << " else ";
757  p.printRegion(elseRegion,
758  /*printEntryBlockArgs=*/false,
759  /*printBlockTerminators=*/printBlockTerminators);
760  }
761 
762  p.printOptionalAttrDict((*this)->getAttrs());
763 }
764 
765 /// Given the region at `index`, or the parent operation if `index` is None,
766 /// return the successor regions. These are the regions that may be selected
767 /// during the flow of control. `operands` is a set of optional attributes
768 /// that correspond to a constant value for each operand, or null if that
769 /// operand is not a constant.
770 void IfOp::getSuccessorRegions(RegionBranchPoint point,
772  // The `then` and the `else` region branch back to the parent operation.
773  if (!point.isParent()) {
774  regions.push_back(RegionSuccessor());
775  return;
776  }
777 
778  regions.push_back(RegionSuccessor(&getThenRegion()));
779 
780  // Don't consider the else region if it is empty.
781  Region *elseRegion = &this->getElseRegion();
782  if (elseRegion->empty())
783  regions.push_back(RegionSuccessor());
784  else
785  regions.push_back(RegionSuccessor(elseRegion));
786 }
787 
788 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
790  FoldAdaptor adaptor(operands, *this);
791  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
792  if (!boolAttr || boolAttr.getValue())
793  regions.emplace_back(&getThenRegion());
794 
795  // If the else region is empty, execution continues after the parent op.
796  if (!boolAttr || !boolAttr.getValue()) {
797  if (!getElseRegion().empty())
798  regions.emplace_back(&getElseRegion());
799  else
800  regions.emplace_back();
801  }
802 }
803 
804 void IfOp::getRegionInvocationBounds(
805  ArrayRef<Attribute> operands,
806  SmallVectorImpl<InvocationBounds> &invocationBounds) {
807  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
808  // If the condition is known, then one region is known to be executed once
809  // and the other zero times.
810  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
811  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
812  } else {
813  // Non-constant condition. Each region may be executed 0 or 1 times.
814  invocationBounds.assign(2, {0, 1});
815  }
816 }
817 
818 //===----------------------------------------------------------------------===//
819 // IncludeOp
820 //===----------------------------------------------------------------------===//
821 
823  bool standardInclude = getIsStandardInclude();
824 
825  p << " ";
826  if (standardInclude)
827  p << "<";
828  p << "\"" << getInclude() << "\"";
829  if (standardInclude)
830  p << ">";
831 }
832 
833 ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
834  bool standardInclude = !parser.parseOptionalLess();
835 
836  StringAttr include;
837  OptionalParseResult includeParseResult =
838  parser.parseOptionalAttribute(include, "include", result.attributes);
839  if (!includeParseResult.has_value())
840  return parser.emitError(parser.getNameLoc()) << "expected string attribute";
841 
842  if (standardInclude && parser.parseOptionalGreater())
843  return parser.emitError(parser.getNameLoc())
844  << "expected trailing '>' for standard include";
845 
846  if (standardInclude)
847  result.addAttribute("is_standard_include",
848  UnitAttr::get(parser.getContext()));
849 
850  return success();
851 }
852 
853 //===----------------------------------------------------------------------===//
854 // LiteralOp
855 //===----------------------------------------------------------------------===//
856 
857 /// The literal op requires a non-empty value.
858 LogicalResult emitc::LiteralOp::verify() {
859  if (getValue().empty())
860  return emitOpError() << "value must not be empty";
861  return success();
862 }
863 //===----------------------------------------------------------------------===//
864 // SubOp
865 //===----------------------------------------------------------------------===//
866 
867 LogicalResult SubOp::verify() {
868  Type lhsType = getLhs().getType();
869  Type rhsType = getRhs().getType();
870  Type resultType = getResult().getType();
871 
872  if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
873  return emitOpError("rhs can only be a pointer if lhs is a pointer");
874 
875  if (isa<emitc::PointerType>(lhsType) &&
876  !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
877  return emitOpError("requires that rhs is an integer, pointer or of opaque "
878  "type if lhs is a pointer");
879 
880  if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
881  !isa<IntegerType, emitc::PtrDiffTType, emitc::OpaqueType>(resultType))
882  return emitOpError("requires that the result is an integer, ptrdiff_t or "
883  "of opaque type if lhs and rhs are pointers");
884  return success();
885 }
886 
887 //===----------------------------------------------------------------------===//
888 // VariableOp
889 //===----------------------------------------------------------------------===//
890 
891 LogicalResult emitc::VariableOp::verify() {
892  return verifyInitializationAttribute(getOperation(), getValueAttr());
893 }
894 
895 //===----------------------------------------------------------------------===//
896 // YieldOp
897 //===----------------------------------------------------------------------===//
898 
899 LogicalResult emitc::YieldOp::verify() {
900  Value result = getResult();
901  Operation *containingOp = getOperation()->getParentOp();
902 
903  if (result && containingOp->getNumResults() != 1)
904  return emitOpError() << "yields a value not returned by parent";
905 
906  if (!result && containingOp->getNumResults() != 0)
907  return emitOpError() << "does not yield a value to be returned by parent";
908 
909  return success();
910 }
911 
912 //===----------------------------------------------------------------------===//
913 // SubscriptOp
914 //===----------------------------------------------------------------------===//
915 
916 LogicalResult emitc::SubscriptOp::verify() {
917  // Checks for array operand.
918  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
919  // Check number of indices.
920  if (getIndices().size() != (size_t)arrayType.getRank()) {
921  return emitOpError() << "on array operand requires number of indices ("
922  << getIndices().size()
923  << ") to match the rank of the array type ("
924  << arrayType.getRank() << ")";
925  }
926  // Check types of index operands.
927  for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
928  Type type = getIndices()[i].getType();
929  if (!isIntegerIndexOrOpaqueType(type)) {
930  return emitOpError() << "on array operand requires index operand " << i
931  << " to be integer-like, but got " << type;
932  }
933  }
934  // Check element type.
935  Type elementType = arrayType.getElementType();
936  Type resultType = getType().getValueType();
937  if (elementType != resultType) {
938  return emitOpError() << "on array operand requires element type ("
939  << elementType << ") and result type (" << resultType
940  << ") to match";
941  }
942  return success();
943  }
944 
945  // Checks for pointer operand.
946  if (auto pointerType =
947  llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
948  // Check number of indices.
949  if (getIndices().size() != 1) {
950  return emitOpError()
951  << "on pointer operand requires one index operand, but got "
952  << getIndices().size();
953  }
954  // Check types of index operand.
955  Type type = getIndices()[0].getType();
956  if (!isIntegerIndexOrOpaqueType(type)) {
957  return emitOpError() << "on pointer operand requires index operand to be "
958  "integer-like, but got "
959  << type;
960  }
961  // Check pointee type.
962  Type pointeeType = pointerType.getPointee();
963  Type resultType = getType().getValueType();
964  if (pointeeType != resultType) {
965  return emitOpError() << "on pointer operand requires pointee type ("
966  << pointeeType << ") and result type (" << resultType
967  << ") to match";
968  }
969  return success();
970  }
971 
972  // The operand has opaque type, so we can't assume anything about the number
973  // or types of index operands.
974  return success();
975 }
976 
977 //===----------------------------------------------------------------------===//
978 // VerbatimOp
979 //===----------------------------------------------------------------------===//
980 
981 LogicalResult emitc::VerbatimOp::verify() {
982  auto errorCallback = [&]() -> InFlightDiagnostic {
983  return this->emitOpError();
984  };
985  FailureOr<SmallVector<ReplacementItem>> fmt =
986  ::parseFormatString(getValue(), getFmtArgs(), errorCallback);
987  if (failed(fmt))
988  return failure();
989 
990  size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) {
991  return std::holds_alternative<Placeholder>(item);
992  });
993 
994  if (numPlaceholders != getFmtArgs().size()) {
995  return emitOpError()
996  << "requires operands for each placeholder in the format string";
997  }
998  return success();
999 }
1000 
1001 FailureOr<SmallVector<ReplacementItem>> emitc::VerbatimOp::parseFormatString() {
1002  // Error checking is done in verify.
1003  return ::parseFormatString(getValue(), getFmtArgs());
1004 }
1005 
1006 //===----------------------------------------------------------------------===//
1007 // EmitC Enums
1008 //===----------------------------------------------------------------------===//
1009 
1010 #include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
1011 
1012 //===----------------------------------------------------------------------===//
1013 // EmitC Attributes
1014 //===----------------------------------------------------------------------===//
1015 
1016 #define GET_ATTRDEF_CLASSES
1017 #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
1018 
1019 //===----------------------------------------------------------------------===//
1020 // EmitC Types
1021 //===----------------------------------------------------------------------===//
1022 
1023 #define GET_TYPEDEF_CLASSES
1024 #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
1025 
1026 //===----------------------------------------------------------------------===//
1027 // ArrayType
1028 //===----------------------------------------------------------------------===//
1029 
1031  if (parser.parseLess())
1032  return Type();
1033 
1034  SmallVector<int64_t, 4> dimensions;
1035  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
1036  /*withTrailingX=*/true))
1037  return Type();
1038  // Parse the element type.
1039  auto typeLoc = parser.getCurrentLocation();
1040  Type elementType;
1041  if (parser.parseType(elementType))
1042  return Type();
1043 
1044  // Check that array is formed from allowed types.
1045  if (!isValidElementType(elementType))
1046  return parser.emitError(typeLoc, "invalid array element type '")
1047  << elementType << "'",
1048  Type();
1049  if (parser.parseGreater())
1050  return Type();
1051  return parser.getChecked<ArrayType>(dimensions, elementType);
1052 }
1053 
1054 void emitc::ArrayType::print(AsmPrinter &printer) const {
1055  printer << "<";
1056  for (int64_t dim : getShape()) {
1057  printer << dim << 'x';
1058  }
1059  printer.printType(getElementType());
1060  printer << ">";
1061 }
1062 
1063 LogicalResult emitc::ArrayType::verify(
1065  ::llvm::ArrayRef<int64_t> shape, Type elementType) {
1066  if (shape.empty())
1067  return emitError() << "shape must not be empty";
1068 
1069  for (int64_t dim : shape) {
1070  if (dim < 0)
1071  return emitError() << "dimensions must have non-negative size";
1072  }
1073 
1074  if (!elementType)
1075  return emitError() << "element type must not be none";
1076 
1077  if (!isValidElementType(elementType))
1078  return emitError() << "invalid array element type";
1079 
1080  return success();
1081 }
1082 
1083 emitc::ArrayType
1084 emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
1085  Type elementType) const {
1086  if (!shape)
1087  return emitc::ArrayType::get(getShape(), elementType);
1088  return emitc::ArrayType::get(*shape, elementType);
1089 }
1090 
1091 //===----------------------------------------------------------------------===//
1092 // LValueType
1093 //===----------------------------------------------------------------------===//
1094 
1095 LogicalResult mlir::emitc::LValueType::verify(
1097  mlir::Type value) {
1098  // Check that the wrapped type is valid. This especially forbids nested
1099  // lvalue types.
1100  if (!isSupportedEmitCType(value))
1101  return emitError()
1102  << "!emitc.lvalue must wrap supported emitc type, but got " << value;
1103 
1104  if (llvm::isa<emitc::ArrayType>(value))
1105  return emitError() << "!emitc.lvalue cannot wrap !emitc.array type";
1106 
1107  return success();
1108 }
1109 
1110 //===----------------------------------------------------------------------===//
1111 // OpaqueType
1112 //===----------------------------------------------------------------------===//
1113 
1114 LogicalResult mlir::emitc::OpaqueType::verify(
1116  llvm::StringRef value) {
1117  if (value.empty()) {
1118  return emitError() << "expected non empty string in !emitc.opaque type";
1119  }
1120  if (value.back() == '*') {
1121  return emitError() << "pointer not allowed as outer type with "
1122  "!emitc.opaque, use !emitc.ptr instead";
1123  }
1124  return success();
1125 }
1126 
1127 //===----------------------------------------------------------------------===//
1128 // PointerType
1129 //===----------------------------------------------------------------------===//
1130 
1131 LogicalResult mlir::emitc::PointerType::verify(
1133  if (llvm::isa<emitc::LValueType>(value))
1134  return emitError() << "pointers to lvalues are not allowed";
1135 
1136  return success();
1137 }
1138 
1139 //===----------------------------------------------------------------------===//
1140 // GlobalOp
1141 //===----------------------------------------------------------------------===//
1143  TypeAttr type,
1144  Attribute initialValue) {
1145  p << type;
1146  if (initialValue) {
1147  p << " = ";
1148  p.printAttributeWithoutType(initialValue);
1149  }
1150 }
1151 
1153  if (auto array = llvm::dyn_cast<ArrayType>(type))
1154  return RankedTensorType::get(array.getShape(), array.getElementType());
1155  return type;
1156 }
1157 
1158 static ParseResult
1160  Attribute &initialValue) {
1161  Type type;
1162  if (parser.parseType(type))
1163  return failure();
1164 
1165  typeAttr = TypeAttr::get(type);
1166 
1167  if (parser.parseOptionalEqual())
1168  return success();
1169 
1170  if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
1171  return failure();
1172 
1173  if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1174  initialValue))
1175  return parser.emitError(parser.getNameLoc())
1176  << "initial value should be a integer, float, elements or opaque "
1177  "attribute";
1178  return success();
1179 }
1180 
1181 LogicalResult GlobalOp::verify() {
1182  if (!isSupportedEmitCType(getType())) {
1183  return emitOpError("expected valid emitc type");
1184  }
1185  if (getInitialValue().has_value()) {
1186  Attribute initValue = getInitialValue().value();
1187  // Check that the type of the initial value is compatible with the type of
1188  // the global variable.
1189  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1190  auto arrayType = llvm::dyn_cast<ArrayType>(getType());
1191  if (!arrayType)
1192  return emitOpError("expected array type, but got ") << getType();
1193 
1194  Type initType = elementsAttr.getType();
1195  Type tensorType = getInitializerTypeForGlobal(getType());
1196  if (initType != tensorType) {
1197  return emitOpError("initial value expected to be of type ")
1198  << getType() << ", but was of type " << initType;
1199  }
1200  } else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
1201  if (intAttr.getType() != getType()) {
1202  return emitOpError("initial value expected to be of type ")
1203  << getType() << ", but was of type " << intAttr.getType();
1204  }
1205  } else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
1206  if (floatAttr.getType() != getType()) {
1207  return emitOpError("initial value expected to be of type ")
1208  << getType() << ", but was of type " << floatAttr.getType();
1209  }
1210  } else if (!isa<emitc::OpaqueAttr>(initValue)) {
1211  return emitOpError("initial value should be a integer, float, elements "
1212  "or opaque attribute, but got ")
1213  << initValue;
1214  }
1215  }
1216  if (getStaticSpecifier() && getExternSpecifier()) {
1217  return emitOpError("cannot have both static and extern specifiers");
1218  }
1219  return success();
1220 }
1221 
1222 //===----------------------------------------------------------------------===//
1223 // GetGlobalOp
1224 //===----------------------------------------------------------------------===//
1225 
1226 LogicalResult
1227 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1228  // Verify that the type matches the type of the global variable.
1229  auto global =
1230  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1231  if (!global)
1232  return emitOpError("'")
1233  << getName() << "' does not reference a valid emitc.global";
1234 
1235  Type resultType = getResult().getType();
1236  Type globalType = global.getType();
1237 
1238  // global has array type
1239  if (llvm::isa<ArrayType>(globalType)) {
1240  if (globalType != resultType)
1241  return emitOpError("on array type expects result type ")
1242  << resultType << " to match type " << globalType
1243  << " of the global @" << getName();
1244  return success();
1245  }
1246 
1247  // global has non-array type
1248  auto lvalueType = dyn_cast<LValueType>(resultType);
1249  if (!lvalueType || lvalueType.getValueType() != globalType)
1250  return emitOpError("on non-array type expects result inner type ")
1251  << lvalueType.getValueType() << " to match type " << globalType
1252  << " of the global @" << getName();
1253  return success();
1254 }
1255 
1256 //===----------------------------------------------------------------------===//
1257 // SwitchOp
1258 //===----------------------------------------------------------------------===//
1259 
1260 /// Parse the case regions and values.
1261 static ParseResult
1263  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1264  SmallVector<int64_t> caseValues;
1265  while (succeeded(parser.parseOptionalKeyword("case"))) {
1266  int64_t value;
1267  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
1268  if (parser.parseInteger(value) ||
1269  parser.parseRegion(region, /*arguments=*/{}))
1270  return failure();
1271  caseValues.push_back(value);
1272  }
1273  cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
1274  return success();
1275 }
1276 
1277 /// Print the case regions and values.
1279  DenseI64ArrayAttr cases, RegionRange caseRegions) {
1280  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
1281  p.printNewline();
1282  p << "case " << value << ' ';
1283  p.printRegion(*region, /*printEntryBlockArgs=*/false);
1284  }
1285 }
1286 
1287 static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
1288  const Twine &name) {
1289  auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
1290  if (!yield)
1291  return op.emitOpError("expected region to end with emitc.yield, but got ")
1292  << region.front().back().getName();
1293 
1294  if (yield.getNumOperands() != 0) {
1295  return (op.emitOpError("expected each region to return ")
1296  << "0 values, but " << name << " returns "
1297  << yield.getNumOperands())
1298  .attachNote(yield.getLoc())
1299  << "see yield operation here";
1300  }
1301 
1302  return success();
1303 }
1304 
1305 LogicalResult emitc::SwitchOp::verify() {
1306  if (!isIntegerIndexOrOpaqueType(getArg().getType()))
1307  return emitOpError("unsupported type ") << getArg().getType();
1308 
1309  if (getCases().size() != getCaseRegions().size()) {
1310  return emitOpError("has ")
1311  << getCaseRegions().size() << " case regions but "
1312  << getCases().size() << " case values";
1313  }
1314 
1315  DenseSet<int64_t> valueSet;
1316  for (int64_t value : getCases())
1317  if (!valueSet.insert(value).second)
1318  return emitOpError("has duplicate case value: ") << value;
1319 
1320  if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
1321  return failure();
1322 
1323  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
1324  if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
1325  return failure();
1326 
1327  return success();
1328 }
1329 
1330 unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }
1331 
1332 Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
1333 
1334 Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
1335  assert(idx < getNumCases() && "case index out-of-bounds");
1336  return getCaseRegions()[idx].front();
1337 }
1338 
1339 void SwitchOp::getSuccessorRegions(
1341  llvm::append_range(successors, getRegions());
1342 }
1343 
1344 void SwitchOp::getEntrySuccessorRegions(
1345  ArrayRef<Attribute> operands,
1346  SmallVectorImpl<RegionSuccessor> &successors) {
1347  FoldAdaptor adaptor(operands, *this);
1348 
1349  // If a constant was not provided, all regions are possible successors.
1350  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
1351  if (!arg) {
1352  llvm::append_range(successors, getRegions());
1353  return;
1354  }
1355 
1356  // Otherwise, try to find a case with a matching value. If not, the
1357  // default region is the only successor.
1358  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
1359  if (caseValue == arg.getInt()) {
1360  successors.emplace_back(&caseRegion);
1361  return;
1362  }
1363  }
1364  successors.emplace_back(&getDefaultRegion());
1365 }
1366 
1367 void SwitchOp::getRegionInvocationBounds(
1369  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
1370  if (!operandValue) {
1371  // All regions are invoked at most once.
1372  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
1373  return;
1374  }
1375 
1376  unsigned liveIndex = getNumRegions() - 1;
1377  const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
1378 
1379  liveIndex = iteratorToInt != getCases().end()
1380  ? std::distance(getCases().begin(), iteratorToInt)
1381  : liveIndex;
1382 
1383  for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
1384  ++regIndex)
1385  bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
1386 }
1387 
1388 //===----------------------------------------------------------------------===//
1389 // FileOp
1390 //===----------------------------------------------------------------------===//
1391 void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
1392  state.addRegion()->emplaceBlock();
1393  state.attributes.push_back(
1394  builder.getNamedAttr("id", builder.getStringAttr(id)));
1395 }
1396 
1397 //===----------------------------------------------------------------------===//
1398 // TableGen'd op method definitions
1399 //===----------------------------------------------------------------------===//
1400 
1401 #define GET_OP_CLASSES
1402 #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:752
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:744
static LogicalResult verifyInitializationAttribute(Operation *op, Attribute value)
Check that the type of the initial value is compatible with the operation's result type.
Definition: EmitC.cpp:142
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1287
static ParseResult parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition: EmitC.cpp:1159
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: EmitC.cpp:1262
static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
Definition: EmitC.cpp:1142
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: EmitC.cpp:1278
static Type getInitializerTypeForGlobal(Type type)
Definition: EmitC.cpp:1152
FailureOr< SmallVector< ReplacementItem > > parseFormatString(StringRef toParse, ArgType fmtArgs, std::optional< llvm::function_ref< mlir::InFlightDiagnostic()>> emitError={})
Parse a format string and return a list of its parts.
Definition: EmitC.cpp:178
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Operation & back()
Definition: Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:246
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:155
bool mightHaveTerminator()
Check whether this block might have a terminator.
Definition: Block.cpp:252
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:163
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:76
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:258
IndexType getIndexType()
Definition: Builders.cpp:51
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:90
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:345
This class helps build Operations.
Definition: Builders.h:204
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:428
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:426
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:673
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
A named class for passing around the variadic flag.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition: OpenACC.cpp:3683
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
std::variant< StringRef, Placeholder > ReplacementItem
Definition: EmitC.h:54
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for builders of ops carrying a region.
Definition: EmitC.cpp:58
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition: EmitC.cpp:117
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Definition: EmitC.cpp:62
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t/ssize_t type.
Definition: EmitC.cpp:135
bool isIntegerIndexOrOpaqueType(Type type)
Determines whether type is integer like, i.e.
Definition: EmitC.cpp:112
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition: EmitC.cpp:96
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:474
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.