MLIR  22.0.0git
EmitC.cpp
Go to the documentation of this file.
1 //===- EmitC.cpp - EmitC Dialect ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/Types.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/TypeSwitch.h"
19 #include "llvm/Support/Casting.h"
20 
21 using namespace mlir;
22 using namespace mlir::emitc;
23 
24 #include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc"
25 
26 //===----------------------------------------------------------------------===//
27 // EmitCDialect
28 //===----------------------------------------------------------------------===//
29 
/// Registers the dialect's operations, types and attributes. Each template
/// argument list is supplied by the `.cpp.inc` file included right after the
/// corresponding `GET_*_LIST` macro is defined.
30 void EmitCDialect::initialize() {
31  addOperations<
32 #define GET_OP_LIST
33 #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
34  >();
35  addTypes<
36 #define GET_TYPEDEF_LIST
37 #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
38  >();
39  addAttributes<
40 #define GET_ATTRDEF_LIST
41 #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
42  >();
43 }
44 
45 /// Materialize a single constant operation from a given attribute value with
46 /// the desired resultant type.
48  Attribute value, Type type,
49  Location loc) {
50  return emitc::ConstantOp::create(builder, loc, type, value);
51 }
52 
53 /// Default callback for builders of ops carrying a region. Inserts a yield
54 /// without arguments.
56  emitc::YieldOp::create(builder, loc);
57 }
58 
60  if (llvm::isa<emitc::OpaqueType>(type))
61  return true;
62  if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
63  return isSupportedEmitCType(ptrType.getPointee());
64  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
65  auto elemType = arrayType.getElementType();
66  return !llvm::isa<emitc::ArrayType>(elemType) &&
67  isSupportedEmitCType(elemType);
68  }
69  if (type.isIndex() || emitc::isPointerWideType(type))
70  return true;
71  if (llvm::isa<IntegerType>(type))
72  return isSupportedIntegerType(type);
73  if (llvm::isa<FloatType>(type))
74  return isSupportedFloatType(type);
75  if (auto tensorType = llvm::dyn_cast<TensorType>(type)) {
76  if (!tensorType.hasStaticShape()) {
77  return false;
78  }
79  auto elemType = tensorType.getElementType();
80  if (llvm::isa<emitc::ArrayType>(elemType)) {
81  return false;
82  }
83  return isSupportedEmitCType(elemType);
84  }
85  if (auto tupleType = llvm::dyn_cast<TupleType>(type)) {
86  return llvm::all_of(tupleType.getTypes(), [](Type type) {
87  return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
88  });
89  }
90  return false;
91 }
92 
94  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
95  switch (intType.getWidth()) {
96  case 1:
97  case 8:
98  case 16:
99  case 32:
100  case 64:
101  return true;
102  default:
103  return false;
104  }
105  }
106  return false;
107 }
108 
110  return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
112 }
113 
115  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
116  switch (floatType.getWidth()) {
117  case 16: {
118  if (llvm::isa<Float16Type, BFloat16Type>(type))
119  return true;
120  return false;
121  }
122  case 32:
123  case 64:
124  return true;
125  default:
126  return false;
127  }
128  }
129  return false;
130 }
131 
133  return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
134  type);
135 }
136 
137 /// Check that the type of the initial value is compatible with the operations
138 /// result type.
139 static LogicalResult verifyInitializationAttribute(Operation *op,
140  Attribute value) {
141  assert(op->getNumResults() == 1 && "operation must have 1 result");
142 
143  if (llvm::isa<emitc::OpaqueAttr>(value))
144  return success();
145 
146  if (llvm::isa<StringAttr>(value))
147  return op->emitOpError()
148  << "string attributes are not supported, use #emitc.opaque instead";
149 
150  Type resultType = op->getResult(0).getType();
151  if (auto lType = dyn_cast<LValueType>(resultType))
152  resultType = lType.getValueType();
153  Type attrType = cast<TypedAttr>(value).getType();
154 
155  if (isPointerWideType(resultType) && attrType.isIndex())
156  return success();
157 
158  if (resultType != attrType)
159  return op->emitOpError()
160  << "requires attribute to either be an #emitc.opaque attribute or "
161  "it's type ("
162  << attrType << ") to match the op's result type (" << resultType
163  << ")";
164 
165  return success();
166 }
167 
168 /// Parse a format string and return a list of its parts.
169 /// A part is either a StringRef that has to be printed as-is, or
170 /// a Placeholder which requires printing the next operand of the VerbatimOp.
171 /// In the format string, all `{}` are replaced by Placeholders, except if the
172 /// `{` is escaped by `{{` - then it doesn't start a placeholder.
173 template <class ArgType>
174 FailureOr<SmallVector<ReplacementItem>> parseFormatString(
175  StringRef toParse, ArgType fmtArgs,
178 
179  // If there are not operands, the format string is not interpreted.
180  if (fmtArgs.empty()) {
181  items.push_back(toParse);
182  return items;
183  }
184 
185  while (!toParse.empty()) {
186  size_t idx = toParse.find('{');
187  if (idx == StringRef::npos) {
188  // No '{'
189  items.push_back(toParse);
190  break;
191  }
192  if (idx > 0) {
193  // Take all chars excluding the '{'.
194  items.push_back(toParse.take_front(idx));
195  toParse = toParse.drop_front(idx);
196  continue;
197  }
198  if (toParse.size() < 2) {
199  return emitError() << "expected '}' after unescaped '{' at end of string";
200  }
201  // toParse contains at least two characters and starts with `{`.
202  char nextChar = toParse[1];
203  if (nextChar == '{') {
204  // Double '{{' -> '{' (escaping).
205  items.push_back(toParse.take_front(1));
206  toParse = toParse.drop_front(2);
207  continue;
208  }
209  if (nextChar == '}') {
210  items.push_back(Placeholder{});
211  toParse = toParse.drop_front(2);
212  continue;
213  }
214 
215  if (emitError) {
216  return emitError() << "expected '}' after unescaped '{'";
217  }
218  return failure();
219  }
220  return items;
221 }
222 
223 //===----------------------------------------------------------------------===//
224 // AddOp
225 //===----------------------------------------------------------------------===//
226 
227 LogicalResult AddOp::verify() {
228  Type lhsType = getLhs().getType();
229  Type rhsType = getRhs().getType();
230 
231  if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
232  return emitOpError("requires that at most one operand is a pointer");
233 
234  if ((isa<emitc::PointerType>(lhsType) &&
235  !isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
236  (isa<emitc::PointerType>(rhsType) &&
237  !isa<IntegerType, emitc::OpaqueType>(lhsType)))
238  return emitOpError("requires that one operand is an integer or of opaque "
239  "type if the other is a pointer");
240 
241  return success();
242 }
243 
244 //===----------------------------------------------------------------------===//
245 // ApplyOp
246 //===----------------------------------------------------------------------===//
247 
248 LogicalResult ApplyOp::verify() {
249  StringRef applicableOperatorStr = getApplicableOperator();
250 
251  // Applicable operator must not be empty.
252  if (applicableOperatorStr.empty())
253  return emitOpError("applicable operator must not be empty");
254 
255  // Only `*` and `&` are supported.
256  if (applicableOperatorStr != "&" && applicableOperatorStr != "*")
257  return emitOpError("applicable operator is illegal");
258 
259  Type operandType = getOperand().getType();
260  Type resultType = getResult().getType();
261  if (applicableOperatorStr == "&") {
262  if (!llvm::isa<emitc::LValueType>(operandType))
263  return emitOpError("operand type must be an lvalue when applying `&`");
264  if (!llvm::isa<emitc::PointerType>(resultType))
265  return emitOpError("result type must be a pointer when applying `&`");
266  } else {
267  if (!llvm::isa<emitc::PointerType>(operandType))
268  return emitOpError("operand type must be a pointer when applying `*`");
269  }
270 
271  return success();
272 }
273 
274 //===----------------------------------------------------------------------===//
275 // AssignOp
276 //===----------------------------------------------------------------------===//
277 
278 /// The assign op requires that the assigned value's type matches the
279 /// assigned-to variable type.
280 LogicalResult emitc::AssignOp::verify() {
282 
283  if (!variable.getDefiningOp())
284  return emitOpError() << "cannot assign to block argument";
285 
286  Type valueType = getValue().getType();
287  Type variableType = variable.getType().getValueType();
288  if (variableType != valueType)
289  return emitOpError() << "requires value's type (" << valueType
290  << ") to match variable's type (" << variableType
291  << ")\n variable: " << variable
292  << "\n value: " << getValue() << "\n";
293  return success();
294 }
295 
296 //===----------------------------------------------------------------------===//
297 // CastOp
298 //===----------------------------------------------------------------------===//
299 
300 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
301  Type input = inputs.front(), output = outputs.front();
302 
303  if (auto arrayType = dyn_cast<emitc::ArrayType>(input)) {
304  if (auto pointerType = dyn_cast<emitc::PointerType>(output)) {
305  return (arrayType.getElementType() == pointerType.getPointee()) &&
306  arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1;
307  }
308  return false;
309  }
310 
311  return (
313  emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
315  emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
316 }
317 
318 //===----------------------------------------------------------------------===//
319 // CallOpaqueOp
320 //===----------------------------------------------------------------------===//
321 
322 LogicalResult emitc::CallOpaqueOp::verify() {
323  // Callee must not be empty.
324  if (getCallee().empty())
325  return emitOpError("callee must not be empty");
326 
327  if (std::optional<ArrayAttr> argsAttr = getArgs()) {
328  for (Attribute arg : *argsAttr) {
329  auto intAttr = llvm::dyn_cast<IntegerAttr>(arg);
330  if (intAttr && llvm::isa<IndexType>(intAttr.getType())) {
331  int64_t index = intAttr.getInt();
332  // Args with elements of type index must be in range
333  // [0..operands.size).
334  if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands())))
335  return emitOpError("index argument is out of range");
336 
337  // Args with elements of type ArrayAttr must have a type.
338  } else if (llvm::isa<ArrayAttr>(
339  arg) /*&& llvm::isa<NoneType>(arg.getType())*/) {
340  // FIXME: Array attributes never have types
341  return emitOpError("array argument has no type");
342  }
343  }
344  }
345 
346  if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
347  for (Attribute tArg : *templateArgsAttr) {
348  if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg))
349  return emitOpError("template argument has invalid type");
350  }
351  }
352 
353  if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
354  return emitOpError() << "cannot return array type";
355  }
356 
357  return success();
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // ConstantOp
362 //===----------------------------------------------------------------------===//
363 
364 LogicalResult emitc::ConstantOp::verify() {
365  Attribute value = getValueAttr();
366  if (failed(verifyInitializationAttribute(getOperation(), value)))
367  return failure();
368  if (auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(value)) {
369  if (opaqueValue.getValue().empty())
370  return emitOpError() << "value must not be empty";
371  }
372  return success();
373 }
374 
375 OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
376 
377 //===----------------------------------------------------------------------===//
378 // ExpressionOp
379 //===----------------------------------------------------------------------===//
380 
381 Operation *ExpressionOp::getRootOp() {
382  auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
383  Value yieldedValue = yieldOp.getResult();
384  return yieldedValue.getDefiningOp();
385 }
386 
387 LogicalResult ExpressionOp::verify() {
388  Type resultType = getResult().getType();
389  Region &region = getRegion();
390 
391  Block &body = region.front();
392 
393  if (!body.mightHaveTerminator())
394  return emitOpError("must yield a value at termination");
395 
396  auto yield = cast<YieldOp>(body.getTerminator());
397  Value yieldResult = yield.getResult();
398 
399  if (!yieldResult)
400  return emitOpError("must yield a value at termination");
401 
402  Operation *rootOp = yieldResult.getDefiningOp();
403 
404  if (!rootOp)
405  return emitOpError("yielded value has no defining op");
406 
407  if (rootOp->getParentOp() != getOperation())
408  return emitOpError("yielded value not defined within expression");
409 
410  Type yieldType = yieldResult.getType();
411 
412  if (resultType != yieldType)
413  return emitOpError("requires yielded type to match return type");
414 
415  for (Operation &op : region.front().without_terminator()) {
416  if (!isa<emitc::CExpressionInterface>(op))
417  return emitOpError("contains an unsupported operation");
418  if (op.getNumResults() != 1)
419  return emitOpError("requires exactly one result for each operation");
420  if (!op.getResult(0).hasOneUse())
421  return emitOpError("requires exactly one use for each operation");
422  }
423 
424  return success();
425 }
426 
427 //===----------------------------------------------------------------------===//
428 // ForOp
429 //===----------------------------------------------------------------------===//
430 
431 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
432  Value ub, Value step, BodyBuilderFn bodyBuilder) {
433  OpBuilder::InsertionGuard g(builder);
434  result.addOperands({lb, ub, step});
435  Type t = lb.getType();
436  Region *bodyRegion = result.addRegion();
437  Block *bodyBlock = builder.createBlock(bodyRegion);
438  bodyBlock->addArgument(t, result.location);
439 
440  // Create the default terminator if the builder is not provided.
441  if (!bodyBuilder) {
442  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
443  } else {
444  OpBuilder::InsertionGuard guard(builder);
445  builder.setInsertionPointToStart(bodyBlock);
446  bodyBuilder(builder, result.location, bodyBlock->getArgument(0));
447  }
448 }
449 
/// Adds no canonicalization patterns: the body is empty and both parameters
/// are unused.
450 void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}
451 
452 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
453  Builder &builder = parser.getBuilder();
454  Type type;
455 
456  OpAsmParser::Argument inductionVariable;
457  OpAsmParser::UnresolvedOperand lb, ub, step;
458 
459  // Parse the induction variable followed by '='.
460  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
461  // Parse loop bounds.
462  parser.parseOperand(lb) || parser.parseKeyword("to") ||
463  parser.parseOperand(ub) || parser.parseKeyword("step") ||
464  parser.parseOperand(step))
465  return failure();
466 
467  // Parse the optional initial iteration arguments.
469  regionArgs.push_back(inductionVariable);
470 
471  // Parse optional type, else assume Index.
472  if (parser.parseOptionalColon())
473  type = builder.getIndexType();
474  else if (parser.parseType(type))
475  return failure();
476 
477  // Resolve input operands.
478  regionArgs.front().type = type;
479  if (parser.resolveOperand(lb, type, result.operands) ||
480  parser.resolveOperand(ub, type, result.operands) ||
481  parser.resolveOperand(step, type, result.operands))
482  return failure();
483 
484  // Parse the body region.
485  Region *body = result.addRegion();
486  if (parser.parseRegion(*body, regionArgs))
487  return failure();
488 
489  ForOp::ensureTerminator(*body, builder, result.location);
490 
491  // Parse the optional attribute list.
492  if (parser.parseOptionalAttrDict(result.attributes))
493  return failure();
494 
495  return success();
496 }
497 
498 void ForOp::print(OpAsmPrinter &p) {
499  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
500  << getUpperBound() << " step " << getStep();
501 
502  p << ' ';
503  if (Type t = getInductionVar().getType(); !t.isIndex())
504  p << " : " << t << ' ';
505  p.printRegion(getRegion(),
506  /*printEntryBlockArgs=*/false,
507  /*printBlockTerminators=*/false);
508  p.printOptionalAttrDict((*this)->getAttrs());
509 }
510 
511 LogicalResult ForOp::verifyRegions() {
512  // Check that the body defines as single block argument for the induction
513  // variable.
514  if (getInductionVar().getType() != getLowerBound().getType())
515  return emitOpError(
516  "expected induction variable to be same type as bounds and step");
517 
518  return success();
519 }
520 
521 //===----------------------------------------------------------------------===//
522 // CallOp
523 //===----------------------------------------------------------------------===//
524 
525 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
526  // Check that the callee attribute was specified.
527  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
528  if (!fnAttr)
529  return emitOpError("requires a 'callee' symbol reference attribute");
530  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
531  if (!fn)
532  return emitOpError() << "'" << fnAttr.getValue()
533  << "' does not reference a valid function";
534 
535  // Verify that the operand and result types match the callee.
536  auto fnType = fn.getFunctionType();
537  if (fnType.getNumInputs() != getNumOperands())
538  return emitOpError("incorrect number of operands for callee");
539 
540  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
541  if (getOperand(i).getType() != fnType.getInput(i))
542  return emitOpError("operand type mismatch: expected operand type ")
543  << fnType.getInput(i) << ", but provided "
544  << getOperand(i).getType() << " for operand number " << i;
545 
546  if (fnType.getNumResults() != getNumResults())
547  return emitOpError("incorrect number of results for callee");
548 
549  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
550  if (getResult(i).getType() != fnType.getResult(i)) {
551  auto diag = emitOpError("result type mismatch at index ") << i;
552  diag.attachNote() << " op result types: " << getResultTypes();
553  diag.attachNote() << "function result types: " << fnType.getResults();
554  return diag;
555  }
556 
557  return success();
558 }
559 
560 FunctionType CallOp::getCalleeType() {
561  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
562 }
563 
564 //===----------------------------------------------------------------------===//
565 // DeclareFuncOp
566 //===----------------------------------------------------------------------===//
567 
568 LogicalResult
569 DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
570  // Check that the sym_name attribute was specified.
571  auto fnAttr = getSymNameAttr();
572  if (!fnAttr)
573  return emitOpError("requires a 'sym_name' symbol reference attribute");
574  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
575  if (!fn)
576  return emitOpError() << "'" << fnAttr.getValue()
577  << "' does not reference a valid function";
578 
579  return success();
580 }
581 
582 //===----------------------------------------------------------------------===//
583 // FuncOp
584 //===----------------------------------------------------------------------===//
585 
586 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
587  FunctionType type, ArrayRef<NamedAttribute> attrs,
588  ArrayRef<DictionaryAttr> argAttrs) {
589  state.addAttribute(SymbolTable::getSymbolAttrName(),
590  builder.getStringAttr(name));
591  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
592  state.attributes.append(attrs.begin(), attrs.end());
593  state.addRegion();
594 
595  if (argAttrs.empty())
596  return;
597  assert(type.getNumInputs() == argAttrs.size());
599  builder, state, argAttrs, /*resultAttrs=*/{},
600  getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
601 }
602 
603 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
604  auto buildFuncType =
605  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
607  std::string &) { return builder.getFunctionType(argTypes, results); };
608 
610  parser, result, /*allowVariadic=*/false,
611  getFunctionTypeAttrName(result.name), buildFuncType,
612  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
613 }
614 
615 void FuncOp::print(OpAsmPrinter &p) {
617  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
618  getArgAttrsAttrName(), getResAttrsAttrName());
619 }
620 
621 LogicalResult FuncOp::verify() {
622  if (llvm::any_of(getArgumentTypes(), llvm::IsaPred<LValueType>)) {
623  return emitOpError("cannot have lvalue type as argument");
624  }
625 
626  if (getNumResults() > 1)
627  return emitOpError("requires zero or exactly one result, but has ")
628  << getNumResults();
629 
630  if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
631  return emitOpError("cannot return array type");
632 
633  return success();
634 }
635 
636 //===----------------------------------------------------------------------===//
637 // ReturnOp
638 //===----------------------------------------------------------------------===//
639 
640 LogicalResult ReturnOp::verify() {
641  auto function = cast<FuncOp>((*this)->getParentOp());
642 
643  // The operand number and types must match the function signature.
644  if (getNumOperands() != function.getNumResults())
645  return emitOpError("has ")
646  << getNumOperands() << " operands, but enclosing function (@"
647  << function.getName() << ") returns " << function.getNumResults();
648 
649  if (function.getNumResults() == 1)
650  if (getOperand().getType() != function.getResultTypes()[0])
651  return emitError() << "type of the return operand ("
652  << getOperand().getType()
653  << ") doesn't match function result type ("
654  << function.getResultTypes()[0] << ")"
655  << " in function @" << function.getName();
656  return success();
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // IfOp
661 //===----------------------------------------------------------------------===//
662 
663 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
664  bool addThenBlock, bool addElseBlock) {
665  assert((!addElseBlock || addThenBlock) &&
666  "must not create else block w/o then block");
667  result.addOperands(cond);
668 
669  // Add regions and blocks.
670  OpBuilder::InsertionGuard guard(builder);
671  Region *thenRegion = result.addRegion();
672  if (addThenBlock)
673  builder.createBlock(thenRegion);
674  Region *elseRegion = result.addRegion();
675  if (addElseBlock)
676  builder.createBlock(elseRegion);
677 }
678 
679 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
680  bool withElseRegion) {
681  result.addOperands(cond);
682 
683  // Build then region.
684  OpBuilder::InsertionGuard guard(builder);
685  Region *thenRegion = result.addRegion();
686  builder.createBlock(thenRegion);
687 
688  // Build else region.
689  Region *elseRegion = result.addRegion();
690  if (withElseRegion) {
691  builder.createBlock(elseRegion);
692  }
693 }
694 
695 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
696  function_ref<void(OpBuilder &, Location)> thenBuilder,
697  function_ref<void(OpBuilder &, Location)> elseBuilder) {
698  assert(thenBuilder && "the builder callback for 'then' must be present");
699  result.addOperands(cond);
700 
701  // Build then region.
702  OpBuilder::InsertionGuard guard(builder);
703  Region *thenRegion = result.addRegion();
704  builder.createBlock(thenRegion);
705  thenBuilder(builder, result.location);
706 
707  // Build else region.
708  Region *elseRegion = result.addRegion();
709  if (elseBuilder) {
710  builder.createBlock(elseRegion);
711  elseBuilder(builder, result.location);
712  }
713 }
714 
715 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
716  // Create the regions for 'then'.
717  result.regions.reserve(2);
718  Region *thenRegion = result.addRegion();
719  Region *elseRegion = result.addRegion();
720 
721  Builder &builder = parser.getBuilder();
723  Type i1Type = builder.getIntegerType(1);
724  if (parser.parseOperand(cond) ||
725  parser.resolveOperand(cond, i1Type, result.operands))
726  return failure();
727  // Parse the 'then' region.
728  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
729  return failure();
730  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
731 
732  // If we find an 'else' keyword then parse the 'else' region.
733  if (!parser.parseOptionalKeyword("else")) {
734  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
735  return failure();
736  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
737  }
738 
739  // Parse the optional attribute list.
740  if (parser.parseOptionalAttrDict(result.attributes))
741  return failure();
742  return success();
743 }
744 
745 void IfOp::print(OpAsmPrinter &p) {
746  bool printBlockTerminators = false;
747 
748  p << " " << getCondition();
749  p << ' ';
750  p.printRegion(getThenRegion(),
751  /*printEntryBlockArgs=*/false,
752  /*printBlockTerminators=*/printBlockTerminators);
753 
754  // Print the 'else' regions if it exists and has a block.
755  Region &elseRegion = getElseRegion();
756  if (!elseRegion.empty()) {
757  p << " else ";
758  p.printRegion(elseRegion,
759  /*printEntryBlockArgs=*/false,
760  /*printBlockTerminators=*/printBlockTerminators);
761  }
762 
763  p.printOptionalAttrDict((*this)->getAttrs());
764 }
765 
766 /// Given the region at `index`, or the parent operation if `index` is None,
767 /// return the successor regions. These are the regions that may be selected
768 /// during the flow of control. `operands` is a set of optional attributes
769 /// that correspond to a constant value for each operand, or null if that
770 /// operand is not a constant.
771 void IfOp::getSuccessorRegions(RegionBranchPoint point,
773  // The `then` and the `else` region branch back to the parent operation.
774  if (!point.isParent()) {
775  regions.push_back(RegionSuccessor());
776  return;
777  }
778 
779  regions.push_back(RegionSuccessor(&getThenRegion()));
780 
781  // Don't consider the else region if it is empty.
782  Region *elseRegion = &this->getElseRegion();
783  if (elseRegion->empty())
784  regions.push_back(RegionSuccessor());
785  else
786  regions.push_back(RegionSuccessor(elseRegion));
787 }
788 
789 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
791  FoldAdaptor adaptor(operands, *this);
792  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
793  if (!boolAttr || boolAttr.getValue())
794  regions.emplace_back(&getThenRegion());
795 
796  // If the else region is empty, execution continues after the parent op.
797  if (!boolAttr || !boolAttr.getValue()) {
798  if (!getElseRegion().empty())
799  regions.emplace_back(&getElseRegion());
800  else
801  regions.emplace_back();
802  }
803 }
804 
805 void IfOp::getRegionInvocationBounds(
806  ArrayRef<Attribute> operands,
807  SmallVectorImpl<InvocationBounds> &invocationBounds) {
808  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
809  // If the condition is known, then one region is known to be executed once
810  // and the other zero times.
811  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
812  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
813  } else {
814  // Non-constant condition. Each region may be executed 0 or 1 times.
815  invocationBounds.assign(2, {0, 1});
816  }
817 }
818 
819 //===----------------------------------------------------------------------===//
820 // IncludeOp
821 //===----------------------------------------------------------------------===//
822 
824  bool standardInclude = getIsStandardInclude();
825 
826  p << " ";
827  if (standardInclude)
828  p << "<";
829  p << "\"" << getInclude() << "\"";
830  if (standardInclude)
831  p << ">";
832 }
833 
834 ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
835  bool standardInclude = !parser.parseOptionalLess();
836 
837  StringAttr include;
838  OptionalParseResult includeParseResult =
839  parser.parseOptionalAttribute(include, "include", result.attributes);
840  if (!includeParseResult.has_value())
841  return parser.emitError(parser.getNameLoc()) << "expected string attribute";
842 
843  if (standardInclude && parser.parseOptionalGreater())
844  return parser.emitError(parser.getNameLoc())
845  << "expected trailing '>' for standard include";
846 
847  if (standardInclude)
848  result.addAttribute("is_standard_include",
849  UnitAttr::get(parser.getContext()));
850 
851  return success();
852 }
853 
854 //===----------------------------------------------------------------------===//
855 // LiteralOp
856 //===----------------------------------------------------------------------===//
857 
858 /// The literal op requires a non-empty value.
859 LogicalResult emitc::LiteralOp::verify() {
860  if (getValue().empty())
861  return emitOpError() << "value must not be empty";
862  return success();
863 }
864 //===----------------------------------------------------------------------===//
865 // SubOp
866 //===----------------------------------------------------------------------===//
867 
868 LogicalResult SubOp::verify() {
869  Type lhsType = getLhs().getType();
870  Type rhsType = getRhs().getType();
871  Type resultType = getResult().getType();
872 
873  if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
874  return emitOpError("rhs can only be a pointer if lhs is a pointer");
875 
876  if (isa<emitc::PointerType>(lhsType) &&
877  !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
878  return emitOpError("requires that rhs is an integer, pointer or of opaque "
879  "type if lhs is a pointer");
880 
881  if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
882  !isa<IntegerType, emitc::PtrDiffTType, emitc::OpaqueType>(resultType))
883  return emitOpError("requires that the result is an integer, ptrdiff_t or "
884  "of opaque type if lhs and rhs are pointers");
885  return success();
886 }
887 
888 //===----------------------------------------------------------------------===//
889 // VariableOp
890 //===----------------------------------------------------------------------===//
891 
892 LogicalResult emitc::VariableOp::verify() {
893  return verifyInitializationAttribute(getOperation(), getValueAttr());
894 }
895 
896 //===----------------------------------------------------------------------===//
897 // YieldOp
898 //===----------------------------------------------------------------------===//
899 
900 LogicalResult emitc::YieldOp::verify() {
901  Value result = getResult();
902  Operation *containingOp = getOperation()->getParentOp();
903 
904  if (result && containingOp->getNumResults() != 1)
905  return emitOpError() << "yields a value not returned by parent";
906 
907  if (!result && containingOp->getNumResults() != 0)
908  return emitOpError() << "does not yield a value to be returned by parent";
909 
910  return success();
911 }
912 
913 //===----------------------------------------------------------------------===//
914 // SubscriptOp
915 //===----------------------------------------------------------------------===//
916 
917 LogicalResult emitc::SubscriptOp::verify() {
918  // Checks for array operand.
919  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
920  // Check number of indices.
921  if (getIndices().size() != (size_t)arrayType.getRank()) {
922  return emitOpError() << "on array operand requires number of indices ("
923  << getIndices().size()
924  << ") to match the rank of the array type ("
925  << arrayType.getRank() << ")";
926  }
927  // Check types of index operands.
928  for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
929  Type type = getIndices()[i].getType();
930  if (!isIntegerIndexOrOpaqueType(type)) {
931  return emitOpError() << "on array operand requires index operand " << i
932  << " to be integer-like, but got " << type;
933  }
934  }
935  // Check element type.
936  Type elementType = arrayType.getElementType();
937  Type resultType = getType().getValueType();
938  if (elementType != resultType) {
939  return emitOpError() << "on array operand requires element type ("
940  << elementType << ") and result type (" << resultType
941  << ") to match";
942  }
943  return success();
944  }
945 
946  // Checks for pointer operand.
947  if (auto pointerType =
948  llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
949  // Check number of indices.
950  if (getIndices().size() != 1) {
951  return emitOpError()
952  << "on pointer operand requires one index operand, but got "
953  << getIndices().size();
954  }
955  // Check types of index operand.
956  Type type = getIndices()[0].getType();
957  if (!isIntegerIndexOrOpaqueType(type)) {
958  return emitOpError() << "on pointer operand requires index operand to be "
959  "integer-like, but got "
960  << type;
961  }
962  // Check pointee type.
963  Type pointeeType = pointerType.getPointee();
964  Type resultType = getType().getValueType();
965  if (pointeeType != resultType) {
966  return emitOpError() << "on pointer operand requires pointee type ("
967  << pointeeType << ") and result type (" << resultType
968  << ") to match";
969  }
970  return success();
971  }
972 
973  // The operand has opaque type, so we can't assume anything about the number
974  // or types of index operands.
975  return success();
976 }
977 
978 //===----------------------------------------------------------------------===//
979 // VerbatimOp
980 //===----------------------------------------------------------------------===//
981 
982 LogicalResult emitc::VerbatimOp::verify() {
983  auto errorCallback = [&]() -> InFlightDiagnostic {
984  return this->emitOpError();
985  };
986  FailureOr<SmallVector<ReplacementItem>> fmt =
987  ::parseFormatString(getValue(), getFmtArgs(), errorCallback);
988  if (failed(fmt))
989  return failure();
990 
991  size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) {
992  return std::holds_alternative<Placeholder>(item);
993  });
994 
995  if (numPlaceholders != getFmtArgs().size()) {
996  return emitOpError()
997  << "requires operands for each placeholder in the format string";
998  }
999  return success();
1000 }
1001 
1002 FailureOr<SmallVector<ReplacementItem>> emitc::VerbatimOp::parseFormatString() {
1003  // Error checking is done in verify.
1004  return ::parseFormatString(getValue(), getFmtArgs());
1005 }
1006 
1007 //===----------------------------------------------------------------------===//
1008 // EmitC Enums
1009 //===----------------------------------------------------------------------===//
1010 
1011 #include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
1012 
1013 //===----------------------------------------------------------------------===//
1014 // EmitC Attributes
1015 //===----------------------------------------------------------------------===//
1016 
1017 #define GET_ATTRDEF_CLASSES
1018 #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
1019 
1020 //===----------------------------------------------------------------------===//
1021 // EmitC Types
1022 //===----------------------------------------------------------------------===//
1023 
1024 #define GET_TYPEDEF_CLASSES
1025 #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
1026 
1027 //===----------------------------------------------------------------------===//
1028 // ArrayType
1029 //===----------------------------------------------------------------------===//
1030 
1032  if (parser.parseLess())
1033  return Type();
1034 
1035  SmallVector<int64_t, 4> dimensions;
1036  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
1037  /*withTrailingX=*/true))
1038  return Type();
1039  // Parse the element type.
1040  auto typeLoc = parser.getCurrentLocation();
1041  Type elementType;
1042  if (parser.parseType(elementType))
1043  return Type();
1044 
1045  // Check that array is formed from allowed types.
1046  if (!isValidElementType(elementType))
1047  return parser.emitError(typeLoc, "invalid array element type '")
1048  << elementType << "'",
1049  Type();
1050  if (parser.parseGreater())
1051  return Type();
1052  return parser.getChecked<ArrayType>(dimensions, elementType);
1053 }
1054 
1055 void emitc::ArrayType::print(AsmPrinter &printer) const {
1056  printer << "<";
1057  for (int64_t dim : getShape()) {
1058  printer << dim << 'x';
1059  }
1060  printer.printType(getElementType());
1061  printer << ">";
1062 }
1063 
1064 LogicalResult emitc::ArrayType::verify(
1066  ::llvm::ArrayRef<int64_t> shape, Type elementType) {
1067  if (shape.empty())
1068  return emitError() << "shape must not be empty";
1069 
1070  for (int64_t dim : shape) {
1071  if (dim < 0)
1072  return emitError() << "dimensions must have non-negative size";
1073  }
1074 
1075  if (!elementType)
1076  return emitError() << "element type must not be none";
1077 
1078  if (!isValidElementType(elementType))
1079  return emitError() << "invalid array element type";
1080 
1081  return success();
1082 }
1083 
1084 emitc::ArrayType
1085 emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
1086  Type elementType) const {
1087  if (!shape)
1088  return emitc::ArrayType::get(getShape(), elementType);
1089  return emitc::ArrayType::get(*shape, elementType);
1090 }
1091 
1092 //===----------------------------------------------------------------------===//
1093 // LValueType
1094 //===----------------------------------------------------------------------===//
1095 
1096 LogicalResult mlir::emitc::LValueType::verify(
1098  mlir::Type value) {
1099  // Check that the wrapped type is valid. This especially forbids nested
1100  // lvalue types.
1101  if (!isSupportedEmitCType(value))
1102  return emitError()
1103  << "!emitc.lvalue must wrap supported emitc type, but got " << value;
1104 
1105  if (llvm::isa<emitc::ArrayType>(value))
1106  return emitError() << "!emitc.lvalue cannot wrap !emitc.array type";
1107 
1108  return success();
1109 }
1110 
1111 //===----------------------------------------------------------------------===//
1112 // OpaqueType
1113 //===----------------------------------------------------------------------===//
1114 
1115 LogicalResult mlir::emitc::OpaqueType::verify(
1117  llvm::StringRef value) {
1118  if (value.empty()) {
1119  return emitError() << "expected non empty string in !emitc.opaque type";
1120  }
1121  if (value.back() == '*') {
1122  return emitError() << "pointer not allowed as outer type with "
1123  "!emitc.opaque, use !emitc.ptr instead";
1124  }
1125  return success();
1126 }
1127 
1128 //===----------------------------------------------------------------------===//
1129 // PointerType
1130 //===----------------------------------------------------------------------===//
1131 
1132 LogicalResult mlir::emitc::PointerType::verify(
1134  if (llvm::isa<emitc::LValueType>(value))
1135  return emitError() << "pointers to lvalues are not allowed";
1136 
1137  return success();
1138 }
1139 
1140 //===----------------------------------------------------------------------===//
1141 // GlobalOp
1142 //===----------------------------------------------------------------------===//
1144  TypeAttr type,
1145  Attribute initialValue) {
1146  p << type;
1147  if (initialValue) {
1148  p << " = ";
1149  p.printAttributeWithoutType(initialValue);
1150  }
1151 }
1152 
1154  if (auto array = llvm::dyn_cast<ArrayType>(type))
1155  return RankedTensorType::get(array.getShape(), array.getElementType());
1156  return type;
1157 }
1158 
1159 static ParseResult
1161  Attribute &initialValue) {
1162  Type type;
1163  if (parser.parseType(type))
1164  return failure();
1165 
1166  typeAttr = TypeAttr::get(type);
1167 
1168  if (parser.parseOptionalEqual())
1169  return success();
1170 
1171  if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
1172  return failure();
1173 
1174  if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1175  initialValue))
1176  return parser.emitError(parser.getNameLoc())
1177  << "initial value should be a integer, float, elements or opaque "
1178  "attribute";
1179  return success();
1180 }
1181 
1182 LogicalResult GlobalOp::verify() {
1183  if (!isSupportedEmitCType(getType())) {
1184  return emitOpError("expected valid emitc type");
1185  }
1186  if (getInitialValue().has_value()) {
1187  Attribute initValue = getInitialValue().value();
1188  // Check that the type of the initial value is compatible with the type of
1189  // the global variable.
1190  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1191  auto arrayType = llvm::dyn_cast<ArrayType>(getType());
1192  if (!arrayType)
1193  return emitOpError("expected array type, but got ") << getType();
1194 
1195  Type initType = elementsAttr.getType();
1196  Type tensorType = getInitializerTypeForGlobal(getType());
1197  if (initType != tensorType) {
1198  return emitOpError("initial value expected to be of type ")
1199  << getType() << ", but was of type " << initType;
1200  }
1201  } else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
1202  if (intAttr.getType() != getType()) {
1203  return emitOpError("initial value expected to be of type ")
1204  << getType() << ", but was of type " << intAttr.getType();
1205  }
1206  } else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
1207  if (floatAttr.getType() != getType()) {
1208  return emitOpError("initial value expected to be of type ")
1209  << getType() << ", but was of type " << floatAttr.getType();
1210  }
1211  } else if (!isa<emitc::OpaqueAttr>(initValue)) {
1212  return emitOpError("initial value should be a integer, float, elements "
1213  "or opaque attribute, but got ")
1214  << initValue;
1215  }
1216  }
1217  if (getStaticSpecifier() && getExternSpecifier()) {
1218  return emitOpError("cannot have both static and extern specifiers");
1219  }
1220  return success();
1221 }
1222 
1223 //===----------------------------------------------------------------------===//
1224 // GetGlobalOp
1225 //===----------------------------------------------------------------------===//
1226 
1227 LogicalResult
1228 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1229  // Verify that the type matches the type of the global variable.
1230  auto global =
1231  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1232  if (!global)
1233  return emitOpError("'")
1234  << getName() << "' does not reference a valid emitc.global";
1235 
1236  Type resultType = getResult().getType();
1237  Type globalType = global.getType();
1238 
1239  // global has array type
1240  if (llvm::isa<ArrayType>(globalType)) {
1241  if (globalType != resultType)
1242  return emitOpError("on array type expects result type ")
1243  << resultType << " to match type " << globalType
1244  << " of the global @" << getName();
1245  return success();
1246  }
1247 
1248  // global has non-array type
1249  auto lvalueType = dyn_cast<LValueType>(resultType);
1250  if (!lvalueType || lvalueType.getValueType() != globalType)
1251  return emitOpError("on non-array type expects result inner type ")
1252  << lvalueType.getValueType() << " to match type " << globalType
1253  << " of the global @" << getName();
1254  return success();
1255 }
1256 
1257 //===----------------------------------------------------------------------===//
1258 // SwitchOp
1259 //===----------------------------------------------------------------------===//
1260 
1261 /// Parse the case regions and values.
1262 static ParseResult
1264  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1265  SmallVector<int64_t> caseValues;
1266  while (succeeded(parser.parseOptionalKeyword("case"))) {
1267  int64_t value;
1268  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
1269  if (parser.parseInteger(value) ||
1270  parser.parseRegion(region, /*arguments=*/{}))
1271  return failure();
1272  caseValues.push_back(value);
1273  }
1274  cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
1275  return success();
1276 }
1277 
1278 /// Print the case regions and values.
1280  DenseI64ArrayAttr cases, RegionRange caseRegions) {
1281  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
1282  p.printNewline();
1283  p << "case " << value << ' ';
1284  p.printRegion(*region, /*printEntryBlockArgs=*/false);
1285  }
1286 }
1287 
1288 static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
1289  const Twine &name) {
1290  auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
1291  if (!yield)
1292  return op.emitOpError("expected region to end with emitc.yield, but got ")
1293  << region.front().back().getName();
1294 
1295  if (yield.getNumOperands() != 0) {
1296  return (op.emitOpError("expected each region to return ")
1297  << "0 values, but " << name << " returns "
1298  << yield.getNumOperands())
1299  .attachNote(yield.getLoc())
1300  << "see yield operation here";
1301  }
1302 
1303  return success();
1304 }
1305 
1306 LogicalResult emitc::SwitchOp::verify() {
1307  if (!isIntegerIndexOrOpaqueType(getArg().getType()))
1308  return emitOpError("unsupported type ") << getArg().getType();
1309 
1310  if (getCases().size() != getCaseRegions().size()) {
1311  return emitOpError("has ")
1312  << getCaseRegions().size() << " case regions but "
1313  << getCases().size() << " case values";
1314  }
1315 
1316  DenseSet<int64_t> valueSet;
1317  for (int64_t value : getCases())
1318  if (!valueSet.insert(value).second)
1319  return emitOpError("has duplicate case value: ") << value;
1320 
1321  if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
1322  return failure();
1323 
1324  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
1325  if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
1326  return failure();
1327 
1328  return success();
1329 }
1330 
1331 unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }
1332 
1333 Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
1334 
1335 Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
1336  assert(idx < getNumCases() && "case index out-of-bounds");
1337  return getCaseRegions()[idx].front();
1338 }
1339 
1340 void SwitchOp::getSuccessorRegions(
1342  llvm::append_range(successors, getRegions());
1343 }
1344 
1345 void SwitchOp::getEntrySuccessorRegions(
1346  ArrayRef<Attribute> operands,
1347  SmallVectorImpl<RegionSuccessor> &successors) {
1348  FoldAdaptor adaptor(operands, *this);
1349 
1350  // If a constant was not provided, all regions are possible successors.
1351  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
1352  if (!arg) {
1353  llvm::append_range(successors, getRegions());
1354  return;
1355  }
1356 
1357  // Otherwise, try to find a case with a matching value. If not, the
1358  // default region is the only successor.
1359  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
1360  if (caseValue == arg.getInt()) {
1361  successors.emplace_back(&caseRegion);
1362  return;
1363  }
1364  }
1365  successors.emplace_back(&getDefaultRegion());
1366 }
1367 
1368 void SwitchOp::getRegionInvocationBounds(
1370  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
1371  if (!operandValue) {
1372  // All regions are invoked at most once.
1373  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
1374  return;
1375  }
1376 
1377  unsigned liveIndex = getNumRegions() - 1;
1378  const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
1379 
1380  liveIndex = iteratorToInt != getCases().end()
1381  ? std::distance(getCases().begin(), iteratorToInt)
1382  : liveIndex;
1383 
1384  for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
1385  ++regIndex)
1386  bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
1387 }
1388 
1389 //===----------------------------------------------------------------------===//
1390 // FileOp
1391 //===----------------------------------------------------------------------===//
1392 void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
1393  state.addRegion()->emplaceBlock();
1394  state.attributes.push_back(
1395  builder.getNamedAttr("id", builder.getStringAttr(id)));
1396 }
1397 
1398 //===----------------------------------------------------------------------===//
1399 // FieldOp
1400 //===----------------------------------------------------------------------===//
1402  TypeAttr type,
1403  Attribute initialValue) {
1404  p << type;
1405  if (initialValue) {
1406  p << " = ";
1407  p.printAttributeWithoutType(initialValue);
1408  }
1409 }
1410 
1412  if (auto array = llvm::dyn_cast<ArrayType>(type))
1413  return RankedTensorType::get(array.getShape(), array.getElementType());
1414  return type;
1415 }
1416 
1417 static ParseResult
1419  Attribute &initialValue) {
1420  Type type;
1421  if (parser.parseType(type))
1422  return failure();
1423 
1424  typeAttr = TypeAttr::get(type);
1425 
1426  if (parser.parseOptionalEqual())
1427  return success();
1428 
1429  if (parser.parseAttribute(initialValue, getInitializerTypeForField(type)))
1430  return failure();
1431 
1432  if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1433  initialValue))
1434  return parser.emitError(parser.getNameLoc())
1435  << "initial value should be a integer, float, elements or opaque "
1436  "attribute";
1437  return success();
1438 }
1439 
1440 LogicalResult FieldOp::verify() {
1441  if (!isSupportedEmitCType(getType()))
1442  return emitOpError("expected valid emitc type");
1443 
1444  Operation *parentOp = getOperation()->getParentOp();
1445  if (!parentOp || !isa<emitc::ClassOp>(parentOp))
1446  return emitOpError("field must be nested within an emitc.class operation");
1447 
1448  StringAttr symName = getSymNameAttr();
1449  if (!symName || symName.getValue().empty())
1450  return emitOpError("field must have a non-empty symbol name");
1451 
1452  return success();
1453 }
1454 
1455 //===----------------------------------------------------------------------===//
1456 // GetFieldOp
1457 //===----------------------------------------------------------------------===//
1458 LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1459  mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
1460  FieldOp fieldOp =
1461  symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, fieldNameAttr);
1462  if (!fieldOp)
1463  return emitOpError("field '")
1464  << fieldNameAttr << "' not found in the class";
1465 
1466  Type getFieldResultType = getResult().getType();
1467  Type fieldType = fieldOp.getType();
1468 
1469  if (fieldType != getFieldResultType)
1470  return emitOpError("result type ")
1471  << getFieldResultType << " does not match field '" << fieldNameAttr
1472  << "' type " << fieldType;
1473 
1474  return success();
1475 }
1476 
1477 //===----------------------------------------------------------------------===//
1478 // TableGen'd op method definitions
1479 //===----------------------------------------------------------------------===//
1480 
1481 #include "mlir/Dialect/EmitC/IR/EmitCInterfaces.cpp.inc"
1482 
1483 #define GET_OP_CLASSES
1484 #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:755
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:747
static LogicalResult verifyInitializationAttribute(Operation *op, Attribute value)
Check that the type of the initial value is compatible with the operations result type.
Definition: EmitC.cpp:139
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1288
static ParseResult parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition: EmitC.cpp:1160
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: EmitC.cpp:1263
static Type getInitializerTypeForField(Type type)
Definition: EmitC.cpp:1411
FailureOr< SmallVector< ReplacementItem > > parseFormatString(StringRef toParse, ArgType fmtArgs, llvm::function_ref< mlir::InFlightDiagnostic()> emitError={})
Parse a format string and return a list of its parts.
Definition: EmitC.cpp:174
static ParseResult parseEmitCFieldOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition: EmitC.cpp:1418
static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
Definition: EmitC.cpp:1143
static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op, TypeAttr type, Attribute initialValue)
Definition: EmitC.cpp:1401
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: EmitC.cpp:1279
static Type getInitializerTypeForGlobal(Type type)
Definition: EmitC.cpp:1153
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Operation & back()
Definition: Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
bool mightHaveTerminator()
Check whether this block might have a terminator.
Definition: Block.cpp:250
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:162
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
IndexType getIndexType()
Definition: Builders.cpp:50
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:89
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
This class represents a single result from folding an operation.
Definition: OpDefinition.h:272
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:40
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:50
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
A named class for passing around the variadic flag.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition: OpenACC.cpp:4061
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments, to the list of operation attributes in result.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
std::variant< StringRef, Placeholder > ReplacementItem
Definition: EmitC.h:54
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for builders of ops carrying a region.
Definition: EmitC.cpp:55
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition: EmitC.cpp:114
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Definition: EmitC.cpp:59
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t/ssize_t type.
Definition: EmitC.cpp:132
bool isIntegerIndexOrOpaqueType(Type type)
Determines whether type is integer like, i.e. it is a supported integer, an index or opaque type.
Definition: EmitC.cpp:109
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition: EmitC.cpp:93
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:21
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute construction code.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.