MLIR  21.0.0git
EmitC.cpp
Go to the documentation of this file.
1 //===- EmitC.cpp - EmitC Dialect ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Types.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Casting.h"
22 #include "llvm/Support/FormatVariadic.h"
23 
24 using namespace mlir;
25 using namespace mlir::emitc;
26 
27 #include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc"
28 
29 //===----------------------------------------------------------------------===//
30 // EmitCDialect
31 //===----------------------------------------------------------------------===//
32 
/// Registers the EmitC dialect's operations, types and attributes. Each list
/// is expanded from a TableGen-generated `.cpp.inc` file selected by the
/// corresponding GET_*_LIST macro.
33 void EmitCDialect::initialize() {
  // Operations (generated list from EmitC.cpp.inc).
34  addOperations<
35 #define GET_OP_LIST
36 #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
37  >();
  // Types (generated list from EmitCTypes.cpp.inc).
38  addTypes<
39 #define GET_TYPEDEF_LIST
40 #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
41  >();
  // Attributes (generated list from EmitCAttributes.cpp.inc).
42  addAttributes<
43 #define GET_ATTRDEF_LIST
44 #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
45  >();
46 }
47 
48 /// Materialize a single constant operation from a given attribute value with
49 /// the desired resultant type.
51  Attribute value, Type type,
52  Location loc) {
53  return builder.create<emitc::ConstantOp>(loc, type, value);
54 }
55 
56 /// Default callback for builders of ops carrying a region. Inserts a yield
57 /// without arguments.
59  builder.create<emitc::YieldOp>(loc);
60 }
61 
63  if (llvm::isa<emitc::OpaqueType>(type))
64  return true;
65  if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
66  return isSupportedEmitCType(ptrType.getPointee());
67  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
68  auto elemType = arrayType.getElementType();
69  return !llvm::isa<emitc::ArrayType>(elemType) &&
70  isSupportedEmitCType(elemType);
71  }
72  if (type.isIndex() || emitc::isPointerWideType(type))
73  return true;
74  if (llvm::isa<IntegerType>(type))
75  return isSupportedIntegerType(type);
76  if (llvm::isa<FloatType>(type))
77  return isSupportedFloatType(type);
78  if (auto tensorType = llvm::dyn_cast<TensorType>(type)) {
79  if (!tensorType.hasStaticShape()) {
80  return false;
81  }
82  auto elemType = tensorType.getElementType();
83  if (llvm::isa<emitc::ArrayType>(elemType)) {
84  return false;
85  }
86  return isSupportedEmitCType(elemType);
87  }
88  if (auto tupleType = llvm::dyn_cast<TupleType>(type)) {
89  return llvm::all_of(tupleType.getTypes(), [](Type type) {
90  return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
91  });
92  }
93  return false;
94 }
95 
97  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
98  switch (intType.getWidth()) {
99  case 1:
100  case 8:
101  case 16:
102  case 32:
103  case 64:
104  return true;
105  default:
106  return false;
107  }
108  }
109  return false;
110 }
111 
113  return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
115 }
116 
118  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
119  switch (floatType.getWidth()) {
120  case 16: {
121  if (llvm::isa<Float16Type, BFloat16Type>(type))
122  return true;
123  return false;
124  }
125  case 32:
126  case 64:
127  return true;
128  default:
129  return false;
130  }
131  }
132  return false;
133 }
134 
136  return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
137  type);
138 }
139 
140 /// Check that the type of the initial value is compatible with the operations
141 /// result type.
142 static LogicalResult verifyInitializationAttribute(Operation *op,
143  Attribute value) {
144  assert(op->getNumResults() == 1 && "operation must have 1 result");
145 
146  if (llvm::isa<emitc::OpaqueAttr>(value))
147  return success();
148 
149  if (llvm::isa<StringAttr>(value))
150  return op->emitOpError()
151  << "string attributes are not supported, use #emitc.opaque instead";
152 
153  Type resultType = op->getResult(0).getType();
154  if (auto lType = dyn_cast<LValueType>(resultType))
155  resultType = lType.getValueType();
156  Type attrType = cast<TypedAttr>(value).getType();
157 
158  if (isPointerWideType(resultType) && attrType.isIndex())
159  return success();
160 
161  if (resultType != attrType)
162  return op->emitOpError()
163  << "requires attribute to either be an #emitc.opaque attribute or "
164  "it's type ("
165  << attrType << ") to match the op's result type (" << resultType
166  << ")";
167 
168  return success();
169 }
170 
171 /// Parse a format string and return a list of its parts.
172 /// A part is either a StringRef that has to be printed as-is, or
173 /// a Placeholder which requires printing the next operand of the VerbatimOp.
174 /// In the format string, all `{}` are replaced by Placeholders, except if the
175 /// `{` is escaped by `{{` - then it doesn't start a placeholder.
176 template <class ArgType>
177 FailureOr<SmallVector<ReplacementItem>>
178 parseFormatString(StringRef toParse, ArgType fmtArgs,
180  emitError = {}) {
182 
183  // If there are not operands, the format string is not interpreted.
184  if (fmtArgs.empty()) {
185  items.push_back(toParse);
186  return items;
187  }
188 
189  while (!toParse.empty()) {
190  size_t idx = toParse.find('{');
191  if (idx == StringRef::npos) {
192  // No '{'
193  items.push_back(toParse);
194  break;
195  }
196  if (idx > 0) {
197  // Take all chars excluding the '{'.
198  items.push_back(toParse.take_front(idx));
199  toParse = toParse.drop_front(idx);
200  continue;
201  }
202  if (toParse.size() < 2) {
203  return (*emitError)()
204  << "expected '}' after unescaped '{' at end of string";
205  }
206  // toParse contains at least two characters and starts with `{`.
207  char nextChar = toParse[1];
208  if (nextChar == '{') {
209  // Double '{{' -> '{' (escaping).
210  items.push_back(toParse.take_front(1));
211  toParse = toParse.drop_front(2);
212  continue;
213  }
214  if (nextChar == '}') {
215  items.push_back(Placeholder{});
216  toParse = toParse.drop_front(2);
217  continue;
218  }
219 
220  if (emitError.has_value()) {
221  return (*emitError)() << "expected '}' after unescaped '{'";
222  }
223  return failure();
224  }
225  return items;
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // AddOp
230 //===----------------------------------------------------------------------===//
231 
232 LogicalResult AddOp::verify() {
233  Type lhsType = getLhs().getType();
234  Type rhsType = getRhs().getType();
235 
236  if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
237  return emitOpError("requires that at most one operand is a pointer");
238 
239  if ((isa<emitc::PointerType>(lhsType) &&
240  !isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
241  (isa<emitc::PointerType>(rhsType) &&
242  !isa<IntegerType, emitc::OpaqueType>(lhsType)))
243  return emitOpError("requires that one operand is an integer or of opaque "
244  "type if the other is a pointer");
245 
246  return success();
247 }
248 
249 //===----------------------------------------------------------------------===//
250 // ApplyOp
251 //===----------------------------------------------------------------------===//
252 
253 LogicalResult ApplyOp::verify() {
254  StringRef applicableOperatorStr = getApplicableOperator();
255 
256  // Applicable operator must not be empty.
257  if (applicableOperatorStr.empty())
258  return emitOpError("applicable operator must not be empty");
259 
260  // Only `*` and `&` are supported.
261  if (applicableOperatorStr != "&" && applicableOperatorStr != "*")
262  return emitOpError("applicable operator is illegal");
263 
264  Type operandType = getOperand().getType();
265  Type resultType = getResult().getType();
266  if (applicableOperatorStr == "&") {
267  if (!llvm::isa<emitc::LValueType>(operandType))
268  return emitOpError("operand type must be an lvalue when applying `&`");
269  if (!llvm::isa<emitc::PointerType>(resultType))
270  return emitOpError("result type must be a pointer when applying `&`");
271  } else {
272  if (!llvm::isa<emitc::PointerType>(operandType))
273  return emitOpError("operand type must be a pointer when applying `*`");
274  }
275 
276  return success();
277 }
278 
279 //===----------------------------------------------------------------------===//
280 // AssignOp
281 //===----------------------------------------------------------------------===//
282 
283 /// The assign op requires that the assigned value's type matches the
284 /// assigned-to variable type.
285 LogicalResult emitc::AssignOp::verify() {
287 
288  if (!variable.getDefiningOp())
289  return emitOpError() << "cannot assign to block argument";
290 
291  Type valueType = getValue().getType();
292  Type variableType = variable.getType().getValueType();
293  if (variableType != valueType)
294  return emitOpError() << "requires value's type (" << valueType
295  << ") to match variable's type (" << variableType
296  << ")\n variable: " << variable
297  << "\n value: " << getValue() << "\n";
298  return success();
299 }
300 
301 //===----------------------------------------------------------------------===//
302 // CastOp
303 //===----------------------------------------------------------------------===//
304 
305 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
306  Type input = inputs.front(), output = outputs.front();
307 
308  if (auto arrayType = dyn_cast<emitc::ArrayType>(input)) {
309  if (auto pointerType = dyn_cast<emitc::PointerType>(output)) {
310  return (arrayType.getElementType() == pointerType.getPointee()) &&
311  arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1;
312  }
313  return false;
314  }
315 
316  return (
318  emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
320  emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
321 }
322 
323 //===----------------------------------------------------------------------===//
324 // CallOpaqueOp
325 //===----------------------------------------------------------------------===//
326 
327 LogicalResult emitc::CallOpaqueOp::verify() {
328  // Callee must not be empty.
329  if (getCallee().empty())
330  return emitOpError("callee must not be empty");
331 
332  if (std::optional<ArrayAttr> argsAttr = getArgs()) {
333  for (Attribute arg : *argsAttr) {
334  auto intAttr = llvm::dyn_cast<IntegerAttr>(arg);
335  if (intAttr && llvm::isa<IndexType>(intAttr.getType())) {
336  int64_t index = intAttr.getInt();
337  // Args with elements of type index must be in range
338  // [0..operands.size).
339  if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands())))
340  return emitOpError("index argument is out of range");
341 
342  // Args with elements of type ArrayAttr must have a type.
343  } else if (llvm::isa<ArrayAttr>(
344  arg) /*&& llvm::isa<NoneType>(arg.getType())*/) {
345  // FIXME: Array attributes never have types
346  return emitOpError("array argument has no type");
347  }
348  }
349  }
350 
351  if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
352  for (Attribute tArg : *templateArgsAttr) {
353  if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg))
354  return emitOpError("template argument has invalid type");
355  }
356  }
357 
358  if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
359  return emitOpError() << "cannot return array type";
360  }
361 
362  return success();
363 }
364 
365 //===----------------------------------------------------------------------===//
366 // ConstantOp
367 //===----------------------------------------------------------------------===//
368 
369 LogicalResult emitc::ConstantOp::verify() {
370  Attribute value = getValueAttr();
371  if (failed(verifyInitializationAttribute(getOperation(), value)))
372  return failure();
373  if (auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(value)) {
374  if (opaqueValue.getValue().empty())
375  return emitOpError() << "value must not be empty";
376  }
377  return success();
378 }
379 
380 OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
381 
382 //===----------------------------------------------------------------------===//
383 // ExpressionOp
384 //===----------------------------------------------------------------------===//
385 
386 Operation *ExpressionOp::getRootOp() {
387  auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
388  Value yieldedValue = yieldOp.getResult();
389  return yieldedValue.getDefiningOp();
390 }
391 
392 LogicalResult ExpressionOp::verify() {
393  Type resultType = getResult().getType();
394  Region &region = getRegion();
395 
396  Block &body = region.front();
397 
398  if (!body.mightHaveTerminator())
399  return emitOpError("must yield a value at termination");
400 
401  auto yield = cast<YieldOp>(body.getTerminator());
402  Value yieldResult = yield.getResult();
403 
404  if (!yieldResult)
405  return emitOpError("must yield a value at termination");
406 
407  Operation *rootOp = yieldResult.getDefiningOp();
408 
409  if (!rootOp)
410  return emitOpError("yielded value has no defining op");
411 
412  if (rootOp->getParentOp() != getOperation())
413  return emitOpError("yielded value not defined within expression");
414 
415  Type yieldType = yieldResult.getType();
416 
417  if (resultType != yieldType)
418  return emitOpError("requires yielded type to match return type");
419 
420  for (Operation &op : region.front().without_terminator()) {
421  if (!isa<emitc::CExpressionInterface>(op))
422  return emitOpError("contains an unsupported operation");
423  if (op.getNumResults() != 1)
424  return emitOpError("requires exactly one result for each operation");
425  if (!op.getResult(0).hasOneUse())
426  return emitOpError("requires exactly one use for each operation");
427  }
428 
429  return success();
430 }
431 
432 //===----------------------------------------------------------------------===//
433 // ForOp
434 //===----------------------------------------------------------------------===//
435 
436 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
437  Value ub, Value step, BodyBuilderFn bodyBuilder) {
438  OpBuilder::InsertionGuard g(builder);
439  result.addOperands({lb, ub, step});
440  Type t = lb.getType();
441  Region *bodyRegion = result.addRegion();
442  Block *bodyBlock = builder.createBlock(bodyRegion);
443  bodyBlock->addArgument(t, result.location);
444 
445  // Create the default terminator if the builder is not provided.
446  if (!bodyBuilder) {
447  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
448  } else {
449  OpBuilder::InsertionGuard guard(builder);
450  builder.setInsertionPointToStart(bodyBlock);
451  bodyBuilder(builder, result.location, bodyBlock->getArgument(0));
452  }
453 }
454 
455 void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}
456 
457 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
458  Builder &builder = parser.getBuilder();
459  Type type;
460 
461  OpAsmParser::Argument inductionVariable;
462  OpAsmParser::UnresolvedOperand lb, ub, step;
463 
464  // Parse the induction variable followed by '='.
465  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
466  // Parse loop bounds.
467  parser.parseOperand(lb) || parser.parseKeyword("to") ||
468  parser.parseOperand(ub) || parser.parseKeyword("step") ||
469  parser.parseOperand(step))
470  return failure();
471 
472  // Parse the optional initial iteration arguments.
474  regionArgs.push_back(inductionVariable);
475 
476  // Parse optional type, else assume Index.
477  if (parser.parseOptionalColon())
478  type = builder.getIndexType();
479  else if (parser.parseType(type))
480  return failure();
481 
482  // Resolve input operands.
483  regionArgs.front().type = type;
484  if (parser.resolveOperand(lb, type, result.operands) ||
485  parser.resolveOperand(ub, type, result.operands) ||
486  parser.resolveOperand(step, type, result.operands))
487  return failure();
488 
489  // Parse the body region.
490  Region *body = result.addRegion();
491  if (parser.parseRegion(*body, regionArgs))
492  return failure();
493 
494  ForOp::ensureTerminator(*body, builder, result.location);
495 
496  // Parse the optional attribute list.
497  if (parser.parseOptionalAttrDict(result.attributes))
498  return failure();
499 
500  return success();
501 }
502 
503 void ForOp::print(OpAsmPrinter &p) {
504  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
505  << getUpperBound() << " step " << getStep();
506 
507  p << ' ';
508  if (Type t = getInductionVar().getType(); !t.isIndex())
509  p << " : " << t << ' ';
510  p.printRegion(getRegion(),
511  /*printEntryBlockArgs=*/false,
512  /*printBlockTerminators=*/false);
513  p.printOptionalAttrDict((*this)->getAttrs());
514 }
515 
516 LogicalResult ForOp::verifyRegions() {
517  // Check that the body defines as single block argument for the induction
518  // variable.
519  if (getInductionVar().getType() != getLowerBound().getType())
520  return emitOpError(
521  "expected induction variable to be same type as bounds and step");
522 
523  return success();
524 }
525 
526 //===----------------------------------------------------------------------===//
527 // CallOp
528 //===----------------------------------------------------------------------===//
529 
530 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
531  // Check that the callee attribute was specified.
532  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
533  if (!fnAttr)
534  return emitOpError("requires a 'callee' symbol reference attribute");
535  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
536  if (!fn)
537  return emitOpError() << "'" << fnAttr.getValue()
538  << "' does not reference a valid function";
539 
540  // Verify that the operand and result types match the callee.
541  auto fnType = fn.getFunctionType();
542  if (fnType.getNumInputs() != getNumOperands())
543  return emitOpError("incorrect number of operands for callee");
544 
545  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
546  if (getOperand(i).getType() != fnType.getInput(i))
547  return emitOpError("operand type mismatch: expected operand type ")
548  << fnType.getInput(i) << ", but provided "
549  << getOperand(i).getType() << " for operand number " << i;
550 
551  if (fnType.getNumResults() != getNumResults())
552  return emitOpError("incorrect number of results for callee");
553 
554  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
555  if (getResult(i).getType() != fnType.getResult(i)) {
556  auto diag = emitOpError("result type mismatch at index ") << i;
557  diag.attachNote() << " op result types: " << getResultTypes();
558  diag.attachNote() << "function result types: " << fnType.getResults();
559  return diag;
560  }
561 
562  return success();
563 }
564 
565 FunctionType CallOp::getCalleeType() {
566  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
567 }
568 
569 //===----------------------------------------------------------------------===//
570 // DeclareFuncOp
571 //===----------------------------------------------------------------------===//
572 
573 LogicalResult
574 DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
575  // Check that the sym_name attribute was specified.
576  auto fnAttr = getSymNameAttr();
577  if (!fnAttr)
578  return emitOpError("requires a 'sym_name' symbol reference attribute");
579  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
580  if (!fn)
581  return emitOpError() << "'" << fnAttr.getValue()
582  << "' does not reference a valid function";
583 
584  return success();
585 }
586 
587 //===----------------------------------------------------------------------===//
588 // FuncOp
589 //===----------------------------------------------------------------------===//
590 
591 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
592  FunctionType type, ArrayRef<NamedAttribute> attrs,
593  ArrayRef<DictionaryAttr> argAttrs) {
594  state.addAttribute(SymbolTable::getSymbolAttrName(),
595  builder.getStringAttr(name));
596  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
597  state.attributes.append(attrs.begin(), attrs.end());
598  state.addRegion();
599 
600  if (argAttrs.empty())
601  return;
602  assert(type.getNumInputs() == argAttrs.size());
604  builder, state, argAttrs, /*resultAttrs=*/{},
605  getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
606 }
607 
608 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
609  auto buildFuncType =
610  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
612  std::string &) { return builder.getFunctionType(argTypes, results); };
613 
615  parser, result, /*allowVariadic=*/false,
616  getFunctionTypeAttrName(result.name), buildFuncType,
617  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
618 }
619 
620 void FuncOp::print(OpAsmPrinter &p) {
622  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
623  getArgAttrsAttrName(), getResAttrsAttrName());
624 }
625 
626 LogicalResult FuncOp::verify() {
627  if (llvm::any_of(getArgumentTypes(), llvm::IsaPred<LValueType>)) {
628  return emitOpError("cannot have lvalue type as argument");
629  }
630 
631  if (getNumResults() > 1)
632  return emitOpError("requires zero or exactly one result, but has ")
633  << getNumResults();
634 
635  if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
636  return emitOpError("cannot return array type");
637 
638  return success();
639 }
640 
641 //===----------------------------------------------------------------------===//
642 // ReturnOp
643 //===----------------------------------------------------------------------===//
644 
645 LogicalResult ReturnOp::verify() {
646  auto function = cast<FuncOp>((*this)->getParentOp());
647 
648  // The operand number and types must match the function signature.
649  if (getNumOperands() != function.getNumResults())
650  return emitOpError("has ")
651  << getNumOperands() << " operands, but enclosing function (@"
652  << function.getName() << ") returns " << function.getNumResults();
653 
654  if (function.getNumResults() == 1)
655  if (getOperand().getType() != function.getResultTypes()[0])
656  return emitError() << "type of the return operand ("
657  << getOperand().getType()
658  << ") doesn't match function result type ("
659  << function.getResultTypes()[0] << ")"
660  << " in function @" << function.getName();
661  return success();
662 }
663 
664 //===----------------------------------------------------------------------===//
665 // IfOp
666 //===----------------------------------------------------------------------===//
667 
668 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
669  bool addThenBlock, bool addElseBlock) {
670  assert((!addElseBlock || addThenBlock) &&
671  "must not create else block w/o then block");
672  result.addOperands(cond);
673 
674  // Add regions and blocks.
675  OpBuilder::InsertionGuard guard(builder);
676  Region *thenRegion = result.addRegion();
677  if (addThenBlock)
678  builder.createBlock(thenRegion);
679  Region *elseRegion = result.addRegion();
680  if (addElseBlock)
681  builder.createBlock(elseRegion);
682 }
683 
684 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
685  bool withElseRegion) {
686  result.addOperands(cond);
687 
688  // Build then region.
689  OpBuilder::InsertionGuard guard(builder);
690  Region *thenRegion = result.addRegion();
691  builder.createBlock(thenRegion);
692 
693  // Build else region.
694  Region *elseRegion = result.addRegion();
695  if (withElseRegion) {
696  builder.createBlock(elseRegion);
697  }
698 }
699 
700 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
701  function_ref<void(OpBuilder &, Location)> thenBuilder,
702  function_ref<void(OpBuilder &, Location)> elseBuilder) {
703  assert(thenBuilder && "the builder callback for 'then' must be present");
704  result.addOperands(cond);
705 
706  // Build then region.
707  OpBuilder::InsertionGuard guard(builder);
708  Region *thenRegion = result.addRegion();
709  builder.createBlock(thenRegion);
710  thenBuilder(builder, result.location);
711 
712  // Build else region.
713  Region *elseRegion = result.addRegion();
714  if (elseBuilder) {
715  builder.createBlock(elseRegion);
716  elseBuilder(builder, result.location);
717  }
718 }
719 
720 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
721  // Create the regions for 'then'.
722  result.regions.reserve(2);
723  Region *thenRegion = result.addRegion();
724  Region *elseRegion = result.addRegion();
725 
726  Builder &builder = parser.getBuilder();
728  Type i1Type = builder.getIntegerType(1);
729  if (parser.parseOperand(cond) ||
730  parser.resolveOperand(cond, i1Type, result.operands))
731  return failure();
732  // Parse the 'then' region.
733  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
734  return failure();
735  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
736 
737  // If we find an 'else' keyword then parse the 'else' region.
738  if (!parser.parseOptionalKeyword("else")) {
739  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
740  return failure();
741  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
742  }
743 
744  // Parse the optional attribute list.
745  if (parser.parseOptionalAttrDict(result.attributes))
746  return failure();
747  return success();
748 }
749 
750 void IfOp::print(OpAsmPrinter &p) {
751  bool printBlockTerminators = false;
752 
753  p << " " << getCondition();
754  p << ' ';
755  p.printRegion(getThenRegion(),
756  /*printEntryBlockArgs=*/false,
757  /*printBlockTerminators=*/printBlockTerminators);
758 
759  // Print the 'else' regions if it exists and has a block.
760  Region &elseRegion = getElseRegion();
761  if (!elseRegion.empty()) {
762  p << " else ";
763  p.printRegion(elseRegion,
764  /*printEntryBlockArgs=*/false,
765  /*printBlockTerminators=*/printBlockTerminators);
766  }
767 
768  p.printOptionalAttrDict((*this)->getAttrs());
769 }
770 
771 /// Given the region at `index`, or the parent operation if `index` is None,
772 /// return the successor regions. These are the regions that may be selected
773 /// during the flow of control. `operands` is a set of optional attributes
774 /// that correspond to a constant value for each operand, or null if that
775 /// operand is not a constant.
776 void IfOp::getSuccessorRegions(RegionBranchPoint point,
778  // The `then` and the `else` region branch back to the parent operation.
779  if (!point.isParent()) {
780  regions.push_back(RegionSuccessor());
781  return;
782  }
783 
784  regions.push_back(RegionSuccessor(&getThenRegion()));
785 
786  // Don't consider the else region if it is empty.
787  Region *elseRegion = &this->getElseRegion();
788  if (elseRegion->empty())
789  regions.push_back(RegionSuccessor());
790  else
791  regions.push_back(RegionSuccessor(elseRegion));
792 }
793 
794 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
796  FoldAdaptor adaptor(operands, *this);
797  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
798  if (!boolAttr || boolAttr.getValue())
799  regions.emplace_back(&getThenRegion());
800 
801  // If the else region is empty, execution continues after the parent op.
802  if (!boolAttr || !boolAttr.getValue()) {
803  if (!getElseRegion().empty())
804  regions.emplace_back(&getElseRegion());
805  else
806  regions.emplace_back();
807  }
808 }
809 
810 void IfOp::getRegionInvocationBounds(
811  ArrayRef<Attribute> operands,
812  SmallVectorImpl<InvocationBounds> &invocationBounds) {
813  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
814  // If the condition is known, then one region is known to be executed once
815  // and the other zero times.
816  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
817  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
818  } else {
819  // Non-constant condition. Each region may be executed 0 or 1 times.
820  invocationBounds.assign(2, {0, 1});
821  }
822 }
823 
824 //===----------------------------------------------------------------------===//
825 // IncludeOp
826 //===----------------------------------------------------------------------===//
827 
829  bool standardInclude = getIsStandardInclude();
830 
831  p << " ";
832  if (standardInclude)
833  p << "<";
834  p << "\"" << getInclude() << "\"";
835  if (standardInclude)
836  p << ">";
837 }
838 
839 ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
840  bool standardInclude = !parser.parseOptionalLess();
841 
842  StringAttr include;
843  OptionalParseResult includeParseResult =
844  parser.parseOptionalAttribute(include, "include", result.attributes);
845  if (!includeParseResult.has_value())
846  return parser.emitError(parser.getNameLoc()) << "expected string attribute";
847 
848  if (standardInclude && parser.parseOptionalGreater())
849  return parser.emitError(parser.getNameLoc())
850  << "expected trailing '>' for standard include";
851 
852  if (standardInclude)
853  result.addAttribute("is_standard_include",
854  UnitAttr::get(parser.getContext()));
855 
856  return success();
857 }
858 
859 //===----------------------------------------------------------------------===//
860 // LiteralOp
861 //===----------------------------------------------------------------------===//
862 
863 /// The literal op requires a non-empty value.
864 LogicalResult emitc::LiteralOp::verify() {
865  if (getValue().empty())
866  return emitOpError() << "value must not be empty";
867  return success();
868 }
869 //===----------------------------------------------------------------------===//
870 // SubOp
871 //===----------------------------------------------------------------------===//
872 
873 LogicalResult SubOp::verify() {
874  Type lhsType = getLhs().getType();
875  Type rhsType = getRhs().getType();
876  Type resultType = getResult().getType();
877 
878  if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
879  return emitOpError("rhs can only be a pointer if lhs is a pointer");
880 
881  if (isa<emitc::PointerType>(lhsType) &&
882  !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
883  return emitOpError("requires that rhs is an integer, pointer or of opaque "
884  "type if lhs is a pointer");
885 
886  if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
887  !isa<IntegerType, emitc::PtrDiffTType, emitc::OpaqueType>(resultType))
888  return emitOpError("requires that the result is an integer, ptrdiff_t or "
889  "of opaque type if lhs and rhs are pointers");
890  return success();
891 }
892 
893 //===----------------------------------------------------------------------===//
894 // VariableOp
895 //===----------------------------------------------------------------------===//
896 
897 LogicalResult emitc::VariableOp::verify() {
898  return verifyInitializationAttribute(getOperation(), getValueAttr());
899 }
900 
901 //===----------------------------------------------------------------------===//
902 // YieldOp
903 //===----------------------------------------------------------------------===//
904 
905 LogicalResult emitc::YieldOp::verify() {
906  Value result = getResult();
907  Operation *containingOp = getOperation()->getParentOp();
908 
909  if (result && containingOp->getNumResults() != 1)
910  return emitOpError() << "yields a value not returned by parent";
911 
912  if (!result && containingOp->getNumResults() != 0)
913  return emitOpError() << "does not yield a value to be returned by parent";
914 
915  return success();
916 }
917 
918 //===----------------------------------------------------------------------===//
919 // SubscriptOp
920 //===----------------------------------------------------------------------===//
921 
922 LogicalResult emitc::SubscriptOp::verify() {
923  // Checks for array operand.
924  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
925  // Check number of indices.
926  if (getIndices().size() != (size_t)arrayType.getRank()) {
927  return emitOpError() << "on array operand requires number of indices ("
928  << getIndices().size()
929  << ") to match the rank of the array type ("
930  << arrayType.getRank() << ")";
931  }
932  // Check types of index operands.
933  for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
934  Type type = getIndices()[i].getType();
935  if (!isIntegerIndexOrOpaqueType(type)) {
936  return emitOpError() << "on array operand requires index operand " << i
937  << " to be integer-like, but got " << type;
938  }
939  }
940  // Check element type.
941  Type elementType = arrayType.getElementType();
942  Type resultType = getType().getValueType();
943  if (elementType != resultType) {
944  return emitOpError() << "on array operand requires element type ("
945  << elementType << ") and result type (" << resultType
946  << ") to match";
947  }
948  return success();
949  }
950 
951  // Checks for pointer operand.
952  if (auto pointerType =
953  llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
954  // Check number of indices.
955  if (getIndices().size() != 1) {
956  return emitOpError()
957  << "on pointer operand requires one index operand, but got "
958  << getIndices().size();
959  }
960  // Check types of index operand.
961  Type type = getIndices()[0].getType();
962  if (!isIntegerIndexOrOpaqueType(type)) {
963  return emitOpError() << "on pointer operand requires index operand to be "
964  "integer-like, but got "
965  << type;
966  }
967  // Check pointee type.
968  Type pointeeType = pointerType.getPointee();
969  Type resultType = getType().getValueType();
970  if (pointeeType != resultType) {
971  return emitOpError() << "on pointer operand requires pointee type ("
972  << pointeeType << ") and result type (" << resultType
973  << ") to match";
974  }
975  return success();
976  }
977 
978  // The operand has opaque type, so we can't assume anything about the number
979  // or types of index operands.
980  return success();
981 }
982 
983 //===----------------------------------------------------------------------===//
984 // VerbatimOp
985 //===----------------------------------------------------------------------===//
986 
987 LogicalResult emitc::VerbatimOp::verify() {
988  auto errorCallback = [&]() -> InFlightDiagnostic {
989  return this->emitOpError();
990  };
991  FailureOr<SmallVector<ReplacementItem>> fmt =
992  ::parseFormatString(getValue(), getFmtArgs(), errorCallback);
993  if (failed(fmt))
994  return failure();
995 
996  size_t numPlaceholders = llvm::count_if(*fmt, [](ReplacementItem &item) {
997  return std::holds_alternative<Placeholder>(item);
998  });
999 
1000  if (numPlaceholders != getFmtArgs().size()) {
1001  return emitOpError()
1002  << "requires operands for each placeholder in the format string";
1003  }
1004  return success();
1005 }
1006 
1007 FailureOr<SmallVector<ReplacementItem>> emitc::VerbatimOp::parseFormatString() {
1008  // Error checking is done in verify.
1009  return ::parseFormatString(getValue(), getFmtArgs());
1010 }
1011 
1012 //===----------------------------------------------------------------------===//
1013 // EmitC Enums
1014 //===----------------------------------------------------------------------===//
1015 
1016 #include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
1017 
1018 //===----------------------------------------------------------------------===//
1019 // EmitC Attributes
1020 //===----------------------------------------------------------------------===//
1021 
1022 #define GET_ATTRDEF_CLASSES
1023 #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
1024 
1025 //===----------------------------------------------------------------------===//
1026 // EmitC Types
1027 //===----------------------------------------------------------------------===//
1028 
1029 #define GET_TYPEDEF_CLASSES
1030 #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
1031 
1032 //===----------------------------------------------------------------------===//
1033 // ArrayType
1034 //===----------------------------------------------------------------------===//
1035 
1037  if (parser.parseLess())
1038  return Type();
1039 
1040  SmallVector<int64_t, 4> dimensions;
1041  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
1042  /*withTrailingX=*/true))
1043  return Type();
1044  // Parse the element type.
1045  auto typeLoc = parser.getCurrentLocation();
1046  Type elementType;
1047  if (parser.parseType(elementType))
1048  return Type();
1049 
1050  // Check that array is formed from allowed types.
1051  if (!isValidElementType(elementType))
1052  return parser.emitError(typeLoc, "invalid array element type '")
1053  << elementType << "'",
1054  Type();
1055  if (parser.parseGreater())
1056  return Type();
1057  return parser.getChecked<ArrayType>(dimensions, elementType);
1058 }
1059 
1060 void emitc::ArrayType::print(AsmPrinter &printer) const {
1061  printer << "<";
1062  for (int64_t dim : getShape()) {
1063  printer << dim << 'x';
1064  }
1065  printer.printType(getElementType());
1066  printer << ">";
1067 }
1068 
1069 LogicalResult emitc::ArrayType::verify(
1071  ::llvm::ArrayRef<int64_t> shape, Type elementType) {
1072  if (shape.empty())
1073  return emitError() << "shape must not be empty";
1074 
1075  for (int64_t dim : shape) {
1076  if (dim < 0)
1077  return emitError() << "dimensions must have non-negative size";
1078  }
1079 
1080  if (!elementType)
1081  return emitError() << "element type must not be none";
1082 
1083  if (!isValidElementType(elementType))
1084  return emitError() << "invalid array element type";
1085 
1086  return success();
1087 }
1088 
1089 emitc::ArrayType
1090 emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
1091  Type elementType) const {
1092  if (!shape)
1093  return emitc::ArrayType::get(getShape(), elementType);
1094  return emitc::ArrayType::get(*shape, elementType);
1095 }
1096 
1097 //===----------------------------------------------------------------------===//
1098 // LValueType
1099 //===----------------------------------------------------------------------===//
1100 
1101 LogicalResult mlir::emitc::LValueType::verify(
1103  mlir::Type value) {
1104  // Check that the wrapped type is valid. This especially forbids nested
1105  // lvalue types.
1106  if (!isSupportedEmitCType(value))
1107  return emitError()
1108  << "!emitc.lvalue must wrap supported emitc type, but got " << value;
1109 
1110  if (llvm::isa<emitc::ArrayType>(value))
1111  return emitError() << "!emitc.lvalue cannot wrap !emitc.array type";
1112 
1113  return success();
1114 }
1115 
1116 //===----------------------------------------------------------------------===//
1117 // OpaqueType
1118 //===----------------------------------------------------------------------===//
1119 
1120 LogicalResult mlir::emitc::OpaqueType::verify(
1122  llvm::StringRef value) {
1123  if (value.empty()) {
1124  return emitError() << "expected non empty string in !emitc.opaque type";
1125  }
1126  if (value.back() == '*') {
1127  return emitError() << "pointer not allowed as outer type with "
1128  "!emitc.opaque, use !emitc.ptr instead";
1129  }
1130  return success();
1131 }
1132 
1133 //===----------------------------------------------------------------------===//
1134 // PointerType
1135 //===----------------------------------------------------------------------===//
1136 
1137 LogicalResult mlir::emitc::PointerType::verify(
1139  if (llvm::isa<emitc::LValueType>(value))
1140  return emitError() << "pointers to lvalues are not allowed";
1141 
1142  return success();
1143 }
1144 
1145 //===----------------------------------------------------------------------===//
1146 // GlobalOp
1147 //===----------------------------------------------------------------------===//
1149  TypeAttr type,
1150  Attribute initialValue) {
1151  p << type;
1152  if (initialValue) {
1153  p << " = ";
1154  p.printAttributeWithoutType(initialValue);
1155  }
1156 }
1157 
1159  if (auto array = llvm::dyn_cast<ArrayType>(type))
1160  return RankedTensorType::get(array.getShape(), array.getElementType());
1161  return type;
1162 }
1163 
1164 static ParseResult
1166  Attribute &initialValue) {
1167  Type type;
1168  if (parser.parseType(type))
1169  return failure();
1170 
1171  typeAttr = TypeAttr::get(type);
1172 
1173  if (parser.parseOptionalEqual())
1174  return success();
1175 
1176  if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
1177  return failure();
1178 
1179  if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1180  initialValue))
1181  return parser.emitError(parser.getNameLoc())
1182  << "initial value should be a integer, float, elements or opaque "
1183  "attribute";
1184  return success();
1185 }
1186 
1187 LogicalResult GlobalOp::verify() {
1188  if (!isSupportedEmitCType(getType())) {
1189  return emitOpError("expected valid emitc type");
1190  }
1191  if (getInitialValue().has_value()) {
1192  Attribute initValue = getInitialValue().value();
1193  // Check that the type of the initial value is compatible with the type of
1194  // the global variable.
1195  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1196  auto arrayType = llvm::dyn_cast<ArrayType>(getType());
1197  if (!arrayType)
1198  return emitOpError("expected array type, but got ") << getType();
1199 
1200  Type initType = elementsAttr.getType();
1201  Type tensorType = getInitializerTypeForGlobal(getType());
1202  if (initType != tensorType) {
1203  return emitOpError("initial value expected to be of type ")
1204  << getType() << ", but was of type " << initType;
1205  }
1206  } else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
1207  if (intAttr.getType() != getType()) {
1208  return emitOpError("initial value expected to be of type ")
1209  << getType() << ", but was of type " << intAttr.getType();
1210  }
1211  } else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
1212  if (floatAttr.getType() != getType()) {
1213  return emitOpError("initial value expected to be of type ")
1214  << getType() << ", but was of type " << floatAttr.getType();
1215  }
1216  } else if (!isa<emitc::OpaqueAttr>(initValue)) {
1217  return emitOpError("initial value should be a integer, float, elements "
1218  "or opaque attribute, but got ")
1219  << initValue;
1220  }
1221  }
1222  if (getStaticSpecifier() && getExternSpecifier()) {
1223  return emitOpError("cannot have both static and extern specifiers");
1224  }
1225  return success();
1226 }
1227 
1228 //===----------------------------------------------------------------------===//
1229 // GetGlobalOp
1230 //===----------------------------------------------------------------------===//
1231 
1232 LogicalResult
1233 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1234  // Verify that the type matches the type of the global variable.
1235  auto global =
1236  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1237  if (!global)
1238  return emitOpError("'")
1239  << getName() << "' does not reference a valid emitc.global";
1240 
1241  Type resultType = getResult().getType();
1242  Type globalType = global.getType();
1243 
1244  // global has array type
1245  if (llvm::isa<ArrayType>(globalType)) {
1246  if (globalType != resultType)
1247  return emitOpError("on array type expects result type ")
1248  << resultType << " to match type " << globalType
1249  << " of the global @" << getName();
1250  return success();
1251  }
1252 
1253  // global has non-array type
1254  auto lvalueType = dyn_cast<LValueType>(resultType);
1255  if (!lvalueType || lvalueType.getValueType() != globalType)
1256  return emitOpError("on non-array type expects result inner type ")
1257  << lvalueType.getValueType() << " to match type " << globalType
1258  << " of the global @" << getName();
1259  return success();
1260 }
1261 
1262 //===----------------------------------------------------------------------===//
1263 // SwitchOp
1264 //===----------------------------------------------------------------------===//
1265 
1266 /// Parse the case regions and values.
1267 static ParseResult
1269  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1270  SmallVector<int64_t> caseValues;
1271  while (succeeded(parser.parseOptionalKeyword("case"))) {
1272  int64_t value;
1273  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
1274  if (parser.parseInteger(value) ||
1275  parser.parseRegion(region, /*arguments=*/{}))
1276  return failure();
1277  caseValues.push_back(value);
1278  }
1279  cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
1280  return success();
1281 }
1282 
1283 /// Print the case regions and values.
1285  DenseI64ArrayAttr cases, RegionRange caseRegions) {
1286  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
1287  p.printNewline();
1288  p << "case " << value << ' ';
1289  p.printRegion(*region, /*printEntryBlockArgs=*/false);
1290  }
1291 }
1292 
1293 static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
1294  const Twine &name) {
1295  auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
1296  if (!yield)
1297  return op.emitOpError("expected region to end with emitc.yield, but got ")
1298  << region.front().back().getName();
1299 
1300  if (yield.getNumOperands() != 0) {
1301  return (op.emitOpError("expected each region to return ")
1302  << "0 values, but " << name << " returns "
1303  << yield.getNumOperands())
1304  .attachNote(yield.getLoc())
1305  << "see yield operation here";
1306  }
1307 
1308  return success();
1309 }
1310 
1311 LogicalResult emitc::SwitchOp::verify() {
1312  if (!isIntegerIndexOrOpaqueType(getArg().getType()))
1313  return emitOpError("unsupported type ") << getArg().getType();
1314 
1315  if (getCases().size() != getCaseRegions().size()) {
1316  return emitOpError("has ")
1317  << getCaseRegions().size() << " case regions but "
1318  << getCases().size() << " case values";
1319  }
1320 
1321  DenseSet<int64_t> valueSet;
1322  for (int64_t value : getCases())
1323  if (!valueSet.insert(value).second)
1324  return emitOpError("has duplicate case value: ") << value;
1325 
1326  if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
1327  return failure();
1328 
1329  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
1330  if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
1331  return failure();
1332 
1333  return success();
1334 }
1335 
1336 unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }
1337 
1338 Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
1339 
1340 Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
1341  assert(idx < getNumCases() && "case index out-of-bounds");
1342  return getCaseRegions()[idx].front();
1343 }
1344 
1345 void SwitchOp::getSuccessorRegions(
1347  llvm::append_range(successors, getRegions());
1348 }
1349 
1350 void SwitchOp::getEntrySuccessorRegions(
1351  ArrayRef<Attribute> operands,
1352  SmallVectorImpl<RegionSuccessor> &successors) {
1353  FoldAdaptor adaptor(operands, *this);
1354 
1355  // If a constant was not provided, all regions are possible successors.
1356  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
1357  if (!arg) {
1358  llvm::append_range(successors, getRegions());
1359  return;
1360  }
1361 
1362  // Otherwise, try to find a case with a matching value. If not, the
1363  // default region is the only successor.
1364  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
1365  if (caseValue == arg.getInt()) {
1366  successors.emplace_back(&caseRegion);
1367  return;
1368  }
1369  }
1370  successors.emplace_back(&getDefaultRegion());
1371 }
1372 
1373 void SwitchOp::getRegionInvocationBounds(
1375  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
1376  if (!operandValue) {
1377  // All regions are invoked at most once.
1378  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
1379  return;
1380  }
1381 
1382  unsigned liveIndex = getNumRegions() - 1;
1383  const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
1384 
1385  liveIndex = iteratorToInt != getCases().end()
1386  ? std::distance(getCases().begin(), iteratorToInt)
1387  : liveIndex;
1388 
1389  for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
1390  ++regIndex)
1391  bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
1392 }
1393 
1394 //===----------------------------------------------------------------------===//
1395 // FileOp
1396 //===----------------------------------------------------------------------===//
1397 void FileOp::build(OpBuilder &builder, OperationState &state, StringRef id) {
1398  state.addRegion()->emplaceBlock();
1399  state.attributes.push_back(
1400  builder.getNamedAttr("id", builder.getStringAttr(id)));
1401 }
1402 
1403 //===----------------------------------------------------------------------===//
1404 // FieldOp
1405 //===----------------------------------------------------------------------===//
1406 LogicalResult FieldOp::verify() {
1407  if (!isSupportedEmitCType(getType()))
1408  return emitOpError("expected valid emitc type");
1409 
1410  Operation *parentOp = getOperation()->getParentOp();
1411  if (!parentOp || !isa<emitc::ClassOp>(parentOp))
1412  return emitOpError("field must be nested within an emitc.class operation");
1413 
1414  StringAttr symName = getSymNameAttr();
1415  if (!symName || symName.getValue().empty())
1416  return emitOpError("field must have a non-empty symbol name");
1417 
1418  if (!getAttrs())
1419  return success();
1420 
1421  return success();
1422 }
1423 
1424 //===----------------------------------------------------------------------===//
1425 // GetFieldOp
1426 //===----------------------------------------------------------------------===//
1427 LogicalResult GetFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1428  mlir::FlatSymbolRefAttr fieldNameAttr = getFieldNameAttr();
1429  FieldOp fieldOp =
1430  symbolTable.lookupNearestSymbolFrom<FieldOp>(*this, fieldNameAttr);
1431  if (!fieldOp)
1432  return emitOpError("field '")
1433  << fieldNameAttr << "' not found in the class";
1434 
1435  Type getFieldResultType = getResult().getType();
1436  Type fieldType = fieldOp.getType();
1437 
1438  if (fieldType != getFieldResultType)
1439  return emitOpError("result type ")
1440  << getFieldResultType << " does not match field '" << fieldNameAttr
1441  << "' type " << fieldType;
1442 
1443  return success();
1444 }
1445 
1446 //===----------------------------------------------------------------------===//
1447 // TableGen'd op method definitions
1448 //===----------------------------------------------------------------------===//
1449 
1450 #include "mlir/Dialect/EmitC/IR/EmitCInterfaces.cpp.inc"
1451 
1452 #define GET_OP_CLASSES
1453 #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:756
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:748
static LogicalResult verifyInitializationAttribute(Operation *op, Attribute value)
Check that the type of the initial value is compatible with the operations result type.
Definition: EmitC.cpp:142
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1293
static ParseResult parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition: EmitC.cpp:1165
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: EmitC.cpp:1268
static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
Definition: EmitC.cpp:1148
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: EmitC.cpp:1284
static Type getInitializerTypeForGlobal(Type type)
Definition: EmitC.cpp:1158
FailureOr< SmallVector< ReplacementItem > > parseFormatString(StringRef toParse, ArgType fmtArgs, std::optional< llvm::function_ref< mlir::InFlightDiagnostic()>> emitError={})
Parse a format string and return a list of its parts.
Definition: EmitC.cpp:178
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:188
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:72
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
Operation & back()
Definition: Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:153
bool mightHaveTerminator()
Check whether this block might have a terminator.
Definition: Block.cpp:250
Operation & front()
Definition: Block.h:153
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:209
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:162
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:75
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:257
IndexType getIndexType()
Definition: Builders.cpp:50
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition: Builders.cpp:89
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:314
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:346
This class helps build Operations.
Definition: Builders.h:205
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:425
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:429
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:452
This class represents a single result from folding an operation.
Definition: OpDefinition.h:271
type_range getType() const
Definition: ValueRange.cpp:32
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:54
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
A named class for passing around the variadic flag.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition: OpenACC.cpp:3937
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
std::variant< StringRef, Placeholder > ReplacementItem
Definition: EmitC.h:54
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for builders of ops carrying a region.
Definition: EmitC.cpp:58
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition: EmitC.cpp:117
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Definition: EmitC.cpp:62
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t/ssize_t type.
Definition: EmitC.cpp:135
bool isIntegerIndexOrOpaqueType(Type type)
Determines whether type is integer-like, i.e. a supported integer, index, or opaque type.
Definition: EmitC.cpp:112
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition: EmitC.cpp:96
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:22
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:488
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:423
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.