MLIR  20.0.0git
EmitC.cpp
Go to the documentation of this file.
1 //===- EmitC.cpp - EmitC Dialect ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Builders.h"
13 #include "mlir/IR/BuiltinTypes.h"
15 #include "mlir/IR/IRMapping.h"
16 #include "mlir/IR/Types.h"
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Casting.h"
22 
23 using namespace mlir;
24 using namespace mlir::emitc;
25 
26 #include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc"
27 
28 //===----------------------------------------------------------------------===//
29 // EmitCDialect
30 //===----------------------------------------------------------------------===//
31 
/// Registers every EmitC operation, type, and attribute with the dialect.
/// The class lists are expanded from the ODS-generated `.inc` files selected
/// by the `GET_*_LIST` macros below.
void EmitCDialect::initialize() {
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
      >();
}
46 
47 /// Materialize a single constant operation from a given attribute value with
48 /// the desired resultant type.
50  Attribute value, Type type,
51  Location loc) {
52  return builder.create<emitc::ConstantOp>(loc, type, value);
53 }
54 
55 /// Default callback for builders of ops carrying a region. Inserts a yield
56 /// without arguments.
58  builder.create<emitc::YieldOp>(loc);
59 }
60 
62  if (llvm::isa<emitc::OpaqueType>(type))
63  return true;
64  if (auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
65  return isSupportedEmitCType(ptrType.getPointee());
66  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
67  auto elemType = arrayType.getElementType();
68  return !llvm::isa<emitc::ArrayType>(elemType) &&
69  isSupportedEmitCType(elemType);
70  }
71  if (type.isIndex() || emitc::isPointerWideType(type))
72  return true;
73  if (llvm::isa<IntegerType>(type))
74  return isSupportedIntegerType(type);
75  if (llvm::isa<FloatType>(type))
76  return isSupportedFloatType(type);
77  if (auto tensorType = llvm::dyn_cast<TensorType>(type)) {
78  if (!tensorType.hasStaticShape()) {
79  return false;
80  }
81  auto elemType = tensorType.getElementType();
82  if (llvm::isa<emitc::ArrayType>(elemType)) {
83  return false;
84  }
85  return isSupportedEmitCType(elemType);
86  }
87  if (auto tupleType = llvm::dyn_cast<TupleType>(type)) {
88  return llvm::all_of(tupleType.getTypes(), [](Type type) {
89  return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
90  });
91  }
92  return false;
93 }
94 
96  if (auto intType = llvm::dyn_cast<IntegerType>(type)) {
97  switch (intType.getWidth()) {
98  case 1:
99  case 8:
100  case 16:
101  case 32:
102  case 64:
103  return true;
104  default:
105  return false;
106  }
107  }
108  return false;
109 }
110 
112  return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
114 }
115 
117  if (auto floatType = llvm::dyn_cast<FloatType>(type)) {
118  switch (floatType.getWidth()) {
119  case 16: {
120  if (llvm::isa<Float16Type, BFloat16Type>(type))
121  return true;
122  return false;
123  }
124  case 32:
125  case 64:
126  return true;
127  default:
128  return false;
129  }
130  }
131  return false;
132 }
133 
135  return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
136  type);
137 }
138 
139 /// Check that the type of the initial value is compatible with the operations
140 /// result type.
141 static LogicalResult verifyInitializationAttribute(Operation *op,
142  Attribute value) {
143  assert(op->getNumResults() == 1 && "operation must have 1 result");
144 
145  if (llvm::isa<emitc::OpaqueAttr>(value))
146  return success();
147 
148  if (llvm::isa<StringAttr>(value))
149  return op->emitOpError()
150  << "string attributes are not supported, use #emitc.opaque instead";
151 
152  Type resultType = op->getResult(0).getType();
153  if (auto lType = dyn_cast<LValueType>(resultType))
154  resultType = lType.getValueType();
155  Type attrType = cast<TypedAttr>(value).getType();
156 
157  if (isPointerWideType(resultType) && attrType.isIndex())
158  return success();
159 
160  if (resultType != attrType)
161  return op->emitOpError()
162  << "requires attribute to either be an #emitc.opaque attribute or "
163  "it's type ("
164  << attrType << ") to match the op's result type (" << resultType
165  << ")";
166 
167  return success();
168 }
169 
170 //===----------------------------------------------------------------------===//
171 // AddOp
172 //===----------------------------------------------------------------------===//
173 
174 LogicalResult AddOp::verify() {
175  Type lhsType = getLhs().getType();
176  Type rhsType = getRhs().getType();
177 
178  if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
179  return emitOpError("requires that at most one operand is a pointer");
180 
181  if ((isa<emitc::PointerType>(lhsType) &&
182  !isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
183  (isa<emitc::PointerType>(rhsType) &&
184  !isa<IntegerType, emitc::OpaqueType>(lhsType)))
185  return emitOpError("requires that one operand is an integer or of opaque "
186  "type if the other is a pointer");
187 
188  return success();
189 }
190 
191 //===----------------------------------------------------------------------===//
192 // ApplyOp
193 //===----------------------------------------------------------------------===//
194 
195 LogicalResult ApplyOp::verify() {
196  StringRef applicableOperatorStr = getApplicableOperator();
197 
198  // Applicable operator must not be empty.
199  if (applicableOperatorStr.empty())
200  return emitOpError("applicable operator must not be empty");
201 
202  // Only `*` and `&` are supported.
203  if (applicableOperatorStr != "&" && applicableOperatorStr != "*")
204  return emitOpError("applicable operator is illegal");
205 
206  Type operandType = getOperand().getType();
207  Type resultType = getResult().getType();
208  if (applicableOperatorStr == "&") {
209  if (!llvm::isa<emitc::LValueType>(operandType))
210  return emitOpError("operand type must be an lvalue when applying `&`");
211  if (!llvm::isa<emitc::PointerType>(resultType))
212  return emitOpError("result type must be a pointer when applying `&`");
213  } else {
214  if (!llvm::isa<emitc::PointerType>(operandType))
215  return emitOpError("operand type must be a pointer when applying `*`");
216  }
217 
218  return success();
219 }
220 
221 //===----------------------------------------------------------------------===//
222 // AssignOp
223 //===----------------------------------------------------------------------===//
224 
225 /// The assign op requires that the assigned value's type matches the
226 /// assigned-to variable type.
227 LogicalResult emitc::AssignOp::verify() {
228  TypedValue<emitc::LValueType> variable = getVar();
229 
230  if (!variable.getDefiningOp())
231  return emitOpError() << "cannot assign to block argument";
232 
233  Type valueType = getValue().getType();
234  Type variableType = variable.getType().getValueType();
235  if (variableType != valueType)
236  return emitOpError() << "requires value's type (" << valueType
237  << ") to match variable's type (" << variableType
238  << ")\n variable: " << variable
239  << "\n value: " << getValue() << "\n";
240  return success();
241 }
242 
243 //===----------------------------------------------------------------------===//
244 // CastOp
245 //===----------------------------------------------------------------------===//
246 
247 bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
248  Type input = inputs.front(), output = outputs.front();
249 
250  return (
252  emitc::isSupportedFloatType(input) || isa<emitc::PointerType>(input)) &&
254  emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
255 }
256 
257 //===----------------------------------------------------------------------===//
258 // CallOpaqueOp
259 //===----------------------------------------------------------------------===//
260 
261 LogicalResult emitc::CallOpaqueOp::verify() {
262  // Callee must not be empty.
263  if (getCallee().empty())
264  return emitOpError("callee must not be empty");
265 
266  if (std::optional<ArrayAttr> argsAttr = getArgs()) {
267  for (Attribute arg : *argsAttr) {
268  auto intAttr = llvm::dyn_cast<IntegerAttr>(arg);
269  if (intAttr && llvm::isa<IndexType>(intAttr.getType())) {
270  int64_t index = intAttr.getInt();
271  // Args with elements of type index must be in range
272  // [0..operands.size).
273  if ((index < 0) || (index >= static_cast<int64_t>(getNumOperands())))
274  return emitOpError("index argument is out of range");
275 
276  // Args with elements of type ArrayAttr must have a type.
277  } else if (llvm::isa<ArrayAttr>(
278  arg) /*&& llvm::isa<NoneType>(arg.getType())*/) {
279  // FIXME: Array attributes never have types
280  return emitOpError("array argument has no type");
281  }
282  }
283  }
284 
285  if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
286  for (Attribute tArg : *templateArgsAttr) {
287  if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg))
288  return emitOpError("template argument has invalid type");
289  }
290  }
291 
292  if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
293  return emitOpError() << "cannot return array type";
294  }
295 
296  return success();
297 }
298 
299 //===----------------------------------------------------------------------===//
300 // ConstantOp
301 //===----------------------------------------------------------------------===//
302 
303 LogicalResult emitc::ConstantOp::verify() {
304  Attribute value = getValueAttr();
305  if (failed(verifyInitializationAttribute(getOperation(), value)))
306  return failure();
307  if (auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(value)) {
308  if (opaqueValue.getValue().empty())
309  return emitOpError() << "value must not be empty";
310  }
311  return success();
312 }
313 
/// Folds the constant to its value attribute; the adaptor is not consulted.
OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { return getValue(); }
315 
316 //===----------------------------------------------------------------------===//
317 // ExpressionOp
318 //===----------------------------------------------------------------------===//
319 
320 Operation *ExpressionOp::getRootOp() {
321  auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
322  Value yieldedValue = yieldOp.getResult();
323  Operation *rootOp = yieldedValue.getDefiningOp();
324  assert(rootOp && "Yielded value not defined within expression");
325  return rootOp;
326 }
327 
328 LogicalResult ExpressionOp::verify() {
329  Type resultType = getResult().getType();
330  Region &region = getRegion();
331 
332  Block &body = region.front();
333 
334  if (!body.mightHaveTerminator())
335  return emitOpError("must yield a value at termination");
336 
337  auto yield = cast<YieldOp>(body.getTerminator());
338  Value yieldResult = yield.getResult();
339 
340  if (!yieldResult)
341  return emitOpError("must yield a value at termination");
342 
343  Type yieldType = yieldResult.getType();
344 
345  if (resultType != yieldType)
346  return emitOpError("requires yielded type to match return type");
347 
348  for (Operation &op : region.front().without_terminator()) {
350  return emitOpError("contains an unsupported operation");
351  if (op.getNumResults() != 1)
352  return emitOpError("requires exactly one result for each operation");
353  if (!op.getResult(0).hasOneUse())
354  return emitOpError("requires exactly one use for each operation");
355  }
356 
357  return success();
358 }
359 
360 //===----------------------------------------------------------------------===//
361 // ForOp
362 //===----------------------------------------------------------------------===//
363 
364 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
365  Value ub, Value step, BodyBuilderFn bodyBuilder) {
366  OpBuilder::InsertionGuard g(builder);
367  result.addOperands({lb, ub, step});
368  Type t = lb.getType();
369  Region *bodyRegion = result.addRegion();
370  Block *bodyBlock = builder.createBlock(bodyRegion);
371  bodyBlock->addArgument(t, result.location);
372 
373  // Create the default terminator if the builder is not provided.
374  if (!bodyBuilder) {
375  ForOp::ensureTerminator(*bodyRegion, builder, result.location);
376  } else {
377  OpBuilder::InsertionGuard guard(builder);
378  builder.setInsertionPointToStart(bodyBlock);
379  bodyBuilder(builder, result.location, bodyBlock->getArgument(0));
380  }
381 }
382 
/// Intentionally empty: no canonicalization patterns are registered for
/// `ForOp`.
void ForOp::getCanonicalizationPatterns(RewritePatternSet &, MLIRContext *) {}
384 
385 ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
386  Builder &builder = parser.getBuilder();
387  Type type;
388 
389  OpAsmParser::Argument inductionVariable;
390  OpAsmParser::UnresolvedOperand lb, ub, step;
391 
392  // Parse the induction variable followed by '='.
393  if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
394  // Parse loop bounds.
395  parser.parseOperand(lb) || parser.parseKeyword("to") ||
396  parser.parseOperand(ub) || parser.parseKeyword("step") ||
397  parser.parseOperand(step))
398  return failure();
399 
400  // Parse the optional initial iteration arguments.
403  regionArgs.push_back(inductionVariable);
404 
405  // Parse optional type, else assume Index.
406  if (parser.parseOptionalColon())
407  type = builder.getIndexType();
408  else if (parser.parseType(type))
409  return failure();
410 
411  // Resolve input operands.
412  regionArgs.front().type = type;
413  if (parser.resolveOperand(lb, type, result.operands) ||
414  parser.resolveOperand(ub, type, result.operands) ||
415  parser.resolveOperand(step, type, result.operands))
416  return failure();
417 
418  // Parse the body region.
419  Region *body = result.addRegion();
420  if (parser.parseRegion(*body, regionArgs))
421  return failure();
422 
423  ForOp::ensureTerminator(*body, builder, result.location);
424 
425  // Parse the optional attribute list.
426  if (parser.parseOptionalAttrDict(result.attributes))
427  return failure();
428 
429  return success();
430 }
431 
432 void ForOp::print(OpAsmPrinter &p) {
433  p << " " << getInductionVar() << " = " << getLowerBound() << " to "
434  << getUpperBound() << " step " << getStep();
435 
436  p << ' ';
437  if (Type t = getInductionVar().getType(); !t.isIndex())
438  p << " : " << t << ' ';
439  p.printRegion(getRegion(),
440  /*printEntryBlockArgs=*/false,
441  /*printBlockTerminators=*/false);
442  p.printOptionalAttrDict((*this)->getAttrs());
443 }
444 
445 LogicalResult ForOp::verifyRegions() {
446  // Check that the body defines as single block argument for the induction
447  // variable.
448  if (getInductionVar().getType() != getLowerBound().getType())
449  return emitOpError(
450  "expected induction variable to be same type as bounds and step");
451 
452  return success();
453 }
454 
455 //===----------------------------------------------------------------------===//
456 // CallOp
457 //===----------------------------------------------------------------------===//
458 
459 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
460  // Check that the callee attribute was specified.
461  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
462  if (!fnAttr)
463  return emitOpError("requires a 'callee' symbol reference attribute");
464  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
465  if (!fn)
466  return emitOpError() << "'" << fnAttr.getValue()
467  << "' does not reference a valid function";
468 
469  // Verify that the operand and result types match the callee.
470  auto fnType = fn.getFunctionType();
471  if (fnType.getNumInputs() != getNumOperands())
472  return emitOpError("incorrect number of operands for callee");
473 
474  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
475  if (getOperand(i).getType() != fnType.getInput(i))
476  return emitOpError("operand type mismatch: expected operand type ")
477  << fnType.getInput(i) << ", but provided "
478  << getOperand(i).getType() << " for operand number " << i;
479 
480  if (fnType.getNumResults() != getNumResults())
481  return emitOpError("incorrect number of results for callee");
482 
483  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
484  if (getResult(i).getType() != fnType.getResult(i)) {
485  auto diag = emitOpError("result type mismatch at index ") << i;
486  diag.attachNote() << " op result types: " << getResultTypes();
487  diag.attachNote() << "function result types: " << fnType.getResults();
488  return diag;
489  }
490 
491  return success();
492 }
493 
/// Returns the function type implied by this call site: the call's operand
/// types as inputs and its result types as results.
FunctionType CallOp::getCalleeType() {
  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
}
497 
498 //===----------------------------------------------------------------------===//
499 // DeclareFuncOp
500 //===----------------------------------------------------------------------===//
501 
502 LogicalResult
503 DeclareFuncOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
504  // Check that the sym_name attribute was specified.
505  auto fnAttr = getSymNameAttr();
506  if (!fnAttr)
507  return emitOpError("requires a 'sym_name' symbol reference attribute");
508  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
509  if (!fn)
510  return emitOpError() << "'" << fnAttr.getValue()
511  << "' does not reference a valid function";
512 
513  return success();
514 }
515 
516 //===----------------------------------------------------------------------===//
517 // FuncOp
518 //===----------------------------------------------------------------------===//
519 
520 void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
521  FunctionType type, ArrayRef<NamedAttribute> attrs,
522  ArrayRef<DictionaryAttr> argAttrs) {
523  state.addAttribute(SymbolTable::getSymbolAttrName(),
524  builder.getStringAttr(name));
525  state.addAttribute(getFunctionTypeAttrName(state.name), TypeAttr::get(type));
526  state.attributes.append(attrs.begin(), attrs.end());
527  state.addRegion();
528 
529  if (argAttrs.empty())
530  return;
531  assert(type.getNumInputs() == argAttrs.size());
533  builder, state, argAttrs, /*resultAttrs=*/std::nullopt,
534  getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
535 }
536 
537 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
538  auto buildFuncType =
539  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
541  std::string &) { return builder.getFunctionType(argTypes, results); };
542 
544  parser, result, /*allowVariadic=*/false,
545  getFunctionTypeAttrName(result.name), buildFuncType,
546  getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
547 }
548 
549 void FuncOp::print(OpAsmPrinter &p) {
551  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
552  getArgAttrsAttrName(), getResAttrsAttrName());
553 }
554 
555 LogicalResult FuncOp::verify() {
556  if (llvm::any_of(getArgumentTypes(), llvm::IsaPred<LValueType>)) {
557  return emitOpError("cannot have lvalue type as argument");
558  }
559 
560  if (getNumResults() > 1)
561  return emitOpError("requires zero or exactly one result, but has ")
562  << getNumResults();
563 
564  if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
565  return emitOpError("cannot return array type");
566 
567  return success();
568 }
569 
570 //===----------------------------------------------------------------------===//
571 // ReturnOp
572 //===----------------------------------------------------------------------===//
573 
574 LogicalResult ReturnOp::verify() {
575  auto function = cast<FuncOp>((*this)->getParentOp());
576 
577  // The operand number and types must match the function signature.
578  if (getNumOperands() != function.getNumResults())
579  return emitOpError("has ")
580  << getNumOperands() << " operands, but enclosing function (@"
581  << function.getName() << ") returns " << function.getNumResults();
582 
583  if (function.getNumResults() == 1)
584  if (getOperand().getType() != function.getResultTypes()[0])
585  return emitError() << "type of the return operand ("
586  << getOperand().getType()
587  << ") doesn't match function result type ("
588  << function.getResultTypes()[0] << ")"
589  << " in function @" << function.getName();
590  return success();
591 }
592 
593 //===----------------------------------------------------------------------===//
594 // IfOp
595 //===----------------------------------------------------------------------===//
596 
597 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
598  bool addThenBlock, bool addElseBlock) {
599  assert((!addElseBlock || addThenBlock) &&
600  "must not create else block w/o then block");
601  result.addOperands(cond);
602 
603  // Add regions and blocks.
604  OpBuilder::InsertionGuard guard(builder);
605  Region *thenRegion = result.addRegion();
606  if (addThenBlock)
607  builder.createBlock(thenRegion);
608  Region *elseRegion = result.addRegion();
609  if (addElseBlock)
610  builder.createBlock(elseRegion);
611 }
612 
613 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
614  bool withElseRegion) {
615  result.addOperands(cond);
616 
617  // Build then region.
618  OpBuilder::InsertionGuard guard(builder);
619  Region *thenRegion = result.addRegion();
620  builder.createBlock(thenRegion);
621 
622  // Build else region.
623  Region *elseRegion = result.addRegion();
624  if (withElseRegion) {
625  builder.createBlock(elseRegion);
626  }
627 }
628 
629 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
630  function_ref<void(OpBuilder &, Location)> thenBuilder,
631  function_ref<void(OpBuilder &, Location)> elseBuilder) {
632  assert(thenBuilder && "the builder callback for 'then' must be present");
633  result.addOperands(cond);
634 
635  // Build then region.
636  OpBuilder::InsertionGuard guard(builder);
637  Region *thenRegion = result.addRegion();
638  builder.createBlock(thenRegion);
639  thenBuilder(builder, result.location);
640 
641  // Build else region.
642  Region *elseRegion = result.addRegion();
643  if (elseBuilder) {
644  builder.createBlock(elseRegion);
645  elseBuilder(builder, result.location);
646  }
647 }
648 
649 ParseResult IfOp::parse(OpAsmParser &parser, OperationState &result) {
650  // Create the regions for 'then'.
651  result.regions.reserve(2);
652  Region *thenRegion = result.addRegion();
653  Region *elseRegion = result.addRegion();
654 
655  Builder &builder = parser.getBuilder();
657  Type i1Type = builder.getIntegerType(1);
658  if (parser.parseOperand(cond) ||
659  parser.resolveOperand(cond, i1Type, result.operands))
660  return failure();
661  // Parse the 'then' region.
662  if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
663  return failure();
664  IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
665 
666  // If we find an 'else' keyword then parse the 'else' region.
667  if (!parser.parseOptionalKeyword("else")) {
668  if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
669  return failure();
670  IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
671  }
672 
673  // Parse the optional attribute list.
674  if (parser.parseOptionalAttrDict(result.attributes))
675  return failure();
676  return success();
677 }
678 
679 void IfOp::print(OpAsmPrinter &p) {
680  bool printBlockTerminators = false;
681 
682  p << " " << getCondition();
683  p << ' ';
684  p.printRegion(getThenRegion(),
685  /*printEntryBlockArgs=*/false,
686  /*printBlockTerminators=*/printBlockTerminators);
687 
688  // Print the 'else' regions if it exists and has a block.
689  Region &elseRegion = getElseRegion();
690  if (!elseRegion.empty()) {
691  p << " else ";
692  p.printRegion(elseRegion,
693  /*printEntryBlockArgs=*/false,
694  /*printBlockTerminators=*/printBlockTerminators);
695  }
696 
697  p.printOptionalAttrDict((*this)->getAttrs());
698 }
699 
700 /// Given the region at `index`, or the parent operation if `index` is None,
701 /// return the successor regions. These are the regions that may be selected
702 /// during the flow of control. `operands` is a set of optional attributes that
703 /// correspond to a constant value for each operand, or null if that operand is
704 /// not a constant.
705 void IfOp::getSuccessorRegions(RegionBranchPoint point,
707  // The `then` and the `else` region branch back to the parent operation.
708  if (!point.isParent()) {
709  regions.push_back(RegionSuccessor());
710  return;
711  }
712 
713  regions.push_back(RegionSuccessor(&getThenRegion()));
714 
715  // Don't consider the else region if it is empty.
716  Region *elseRegion = &this->getElseRegion();
717  if (elseRegion->empty())
718  regions.push_back(RegionSuccessor());
719  else
720  regions.push_back(RegionSuccessor(elseRegion));
721 }
722 
723 void IfOp::getEntrySuccessorRegions(ArrayRef<Attribute> operands,
725  FoldAdaptor adaptor(operands, *this);
726  auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
727  if (!boolAttr || boolAttr.getValue())
728  regions.emplace_back(&getThenRegion());
729 
730  // If the else region is empty, execution continues after the parent op.
731  if (!boolAttr || !boolAttr.getValue()) {
732  if (!getElseRegion().empty())
733  regions.emplace_back(&getElseRegion());
734  else
735  regions.emplace_back();
736  }
737 }
738 
739 void IfOp::getRegionInvocationBounds(
740  ArrayRef<Attribute> operands,
741  SmallVectorImpl<InvocationBounds> &invocationBounds) {
742  if (auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
743  // If the condition is known, then one region is known to be executed once
744  // and the other zero times.
745  invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
746  invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
747  } else {
748  // Non-constant condition. Each region may be executed 0 or 1 times.
749  invocationBounds.assign(2, {0, 1});
750  }
751 }
752 
753 //===----------------------------------------------------------------------===//
754 // IncludeOp
755 //===----------------------------------------------------------------------===//
756 
758  bool standardInclude = getIsStandardInclude();
759 
760  p << " ";
761  if (standardInclude)
762  p << "<";
763  p << "\"" << getInclude() << "\"";
764  if (standardInclude)
765  p << ">";
766 }
767 
768 ParseResult IncludeOp::parse(OpAsmParser &parser, OperationState &result) {
769  bool standardInclude = !parser.parseOptionalLess();
770 
771  StringAttr include;
772  OptionalParseResult includeParseResult =
773  parser.parseOptionalAttribute(include, "include", result.attributes);
774  if (!includeParseResult.has_value())
775  return parser.emitError(parser.getNameLoc()) << "expected string attribute";
776 
777  if (standardInclude && parser.parseOptionalGreater())
778  return parser.emitError(parser.getNameLoc())
779  << "expected trailing '>' for standard include";
780 
781  if (standardInclude)
782  result.addAttribute("is_standard_include",
783  UnitAttr::get(parser.getContext()));
784 
785  return success();
786 }
787 
788 //===----------------------------------------------------------------------===//
789 // LiteralOp
790 //===----------------------------------------------------------------------===//
791 
792 /// The literal op requires a non-empty value.
793 LogicalResult emitc::LiteralOp::verify() {
794  if (getValue().empty())
795  return emitOpError() << "value must not be empty";
796  return success();
797 }
798 //===----------------------------------------------------------------------===//
799 // SubOp
800 //===----------------------------------------------------------------------===//
801 
802 LogicalResult SubOp::verify() {
803  Type lhsType = getLhs().getType();
804  Type rhsType = getRhs().getType();
805  Type resultType = getResult().getType();
806 
807  if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
808  return emitOpError("rhs can only be a pointer if lhs is a pointer");
809 
810  if (isa<emitc::PointerType>(lhsType) &&
811  !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
812  return emitOpError("requires that rhs is an integer, pointer or of opaque "
813  "type if lhs is a pointer");
814 
815  if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
816  !isa<IntegerType, emitc::PtrDiffTType, emitc::OpaqueType>(resultType))
817  return emitOpError("requires that the result is an integer, ptrdiff_t or "
818  "of opaque type if lhs and rhs are pointers");
819  return success();
820 }
821 
822 //===----------------------------------------------------------------------===//
823 // VariableOp
824 //===----------------------------------------------------------------------===//
825 
/// Verifies the variable's initial value attribute against its result type,
/// using the same `verifyInitializationAttribute` check as `ConstantOp`.
LogicalResult emitc::VariableOp::verify() {
  return verifyInitializationAttribute(getOperation(), getValueAttr());
}
829 
830 //===----------------------------------------------------------------------===//
831 // YieldOp
832 //===----------------------------------------------------------------------===//
833 
834 LogicalResult emitc::YieldOp::verify() {
835  Value result = getResult();
836  Operation *containingOp = getOperation()->getParentOp();
837 
838  if (result && containingOp->getNumResults() != 1)
839  return emitOpError() << "yields a value not returned by parent";
840 
841  if (!result && containingOp->getNumResults() != 0)
842  return emitOpError() << "does not yield a value to be returned by parent";
843 
844  return success();
845 }
846 
847 //===----------------------------------------------------------------------===//
848 // SubscriptOp
849 //===----------------------------------------------------------------------===//
850 
851 LogicalResult emitc::SubscriptOp::verify() {
852  // Checks for array operand.
853  if (auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().getType())) {
854  // Check number of indices.
855  if (getIndices().size() != (size_t)arrayType.getRank()) {
856  return emitOpError() << "on array operand requires number of indices ("
857  << getIndices().size()
858  << ") to match the rank of the array type ("
859  << arrayType.getRank() << ")";
860  }
861  // Check types of index operands.
862  for (unsigned i = 0, e = getIndices().size(); i != e; ++i) {
863  Type type = getIndices()[i].getType();
864  if (!isIntegerIndexOrOpaqueType(type)) {
865  return emitOpError() << "on array operand requires index operand " << i
866  << " to be integer-like, but got " << type;
867  }
868  }
869  // Check element type.
870  Type elementType = arrayType.getElementType();
871  Type resultType = getType().getValueType();
872  if (elementType != resultType) {
873  return emitOpError() << "on array operand requires element type ("
874  << elementType << ") and result type (" << resultType
875  << ") to match";
876  }
877  return success();
878  }
879 
880  // Checks for pointer operand.
881  if (auto pointerType =
882  llvm::dyn_cast<emitc::PointerType>(getValue().getType())) {
883  // Check number of indices.
884  if (getIndices().size() != 1) {
885  return emitOpError()
886  << "on pointer operand requires one index operand, but got "
887  << getIndices().size();
888  }
889  // Check types of index operand.
890  Type type = getIndices()[0].getType();
891  if (!isIntegerIndexOrOpaqueType(type)) {
892  return emitOpError() << "on pointer operand requires index operand to be "
893  "integer-like, but got "
894  << type;
895  }
896  // Check pointee type.
897  Type pointeeType = pointerType.getPointee();
898  Type resultType = getType().getValueType();
899  if (pointeeType != resultType) {
900  return emitOpError() << "on pointer operand requires pointee type ("
901  << pointeeType << ") and result type (" << resultType
902  << ") to match";
903  }
904  return success();
905  }
906 
907  // The operand has opaque type, so we can't assume anything about the number
908  // or types of index operands.
909  return success();
910 }
911 
912 //===----------------------------------------------------------------------===//
913 // EmitC Enums
914 //===----------------------------------------------------------------------===//
915 
916 #include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
917 
918 //===----------------------------------------------------------------------===//
919 // EmitC Attributes
920 //===----------------------------------------------------------------------===//
921 
922 #define GET_ATTRDEF_CLASSES
923 #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
924 
925 //===----------------------------------------------------------------------===//
926 // EmitC Types
927 //===----------------------------------------------------------------------===//
928 
929 #define GET_TYPEDEF_CLASSES
930 #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
931 
932 //===----------------------------------------------------------------------===//
933 // ArrayType
934 //===----------------------------------------------------------------------===//
935 
// Parses the body of an `!emitc.array` type: `<` d0 `x` d1 `x` ... elem `>`.
// Dimensions must be static (allowDynamic=false), and each one is followed by
// `x` (withTrailingX=true). Every failure path returns a null Type. The element
// type is checked with isValidElementType before the closing `>`. The final
// type is built through parser.getChecked, which reports verification errors.
// NOTE(review): the function header (source line 936) is missing from this
// listing; presumably `Type emitc::ArrayType::parse(AsmParser &parser) {`.
937  if (parser.parseLess())
938  return Type();
939 
940  SmallVector<int64_t, 4> dimensions;
941  if (parser.parseDimensionList(dimensions, /*allowDynamic=*/false,
942  /*withTrailingX=*/true))
943  return Type();
944  // Parse the element type.
945  auto typeLoc = parser.getCurrentLocation();
946  Type elementType;
947  if (parser.parseType(elementType))
948  return Type();
949 
950  // Check that array is formed from allowed types.
951  if (!isValidElementType(elementType))
952  return parser.emitError(typeLoc, "invalid array element type"), Type();
953  if (parser.parseGreater())
954  return Type();
955  return parser.getChecked<ArrayType>(dimensions, elementType);
956 }
957 
958 void emitc::ArrayType::print(AsmPrinter &printer) const {
959  printer << "<";
960  for (int64_t dim : getShape()) {
961  printer << dim << 'x';
962  }
963  printer.printType(getElementType());
964  printer << ">";
965 }
966 
/// Verifies an `!emitc.array` type. The shape must be non-empty, every
/// dimension must be strictly positive, and the element type must be non-null
/// and accepted by isValidElementType.
/// NOTE(review): source line 968 (presumably the `emitError` callback
/// parameter used below) is missing from this listing.
967 LogicalResult emitc::ArrayType::verify(
969  ::llvm::ArrayRef<int64_t> shape, Type elementType) {
970  if (shape.empty())
971  return emitError() << "shape must not be empty";
972 
973  for (int64_t dim : shape) {
974  if (dim <= 0)
975  return emitError() << "dimensions must have positive size";
976  }
977 
978  if (!elementType)
979  return emitError() << "element type must not be none";
980 
981  if (!isValidElementType(elementType))
982  return emitError() << "invalid array element type";
983 
984  return success();
985 }
986 
987 emitc::ArrayType
988 emitc::ArrayType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
989  Type elementType) const {
990  if (!shape)
991  return emitc::ArrayType::get(getShape(), elementType);
992  return emitc::ArrayType::get(*shape, elementType);
993 }
994 
995 //===----------------------------------------------------------------------===//
996 // LValueType
997 //===----------------------------------------------------------------------===//
998 
/// Verifies `!emitc.lvalue<T>`. T must be a supported EmitC type, which rules
/// out nested lvalues, and T must not be an `!emitc.array`.
/// NOTE(review): source line 1000 (presumably the `emitError` callback
/// parameter) is missing from this listing.
999 LogicalResult mlir::emitc::LValueType::verify(
1001  mlir::Type value) {
1002  // Check that the wrapped type is valid. This especially forbids nested lvalue
1003  // types.
1004  if (!isSupportedEmitCType(value))
1005  return emitError()
1006  << "!emitc.lvalue must wrap supported emitc type, but got " << value;
1007 
1008  if (llvm::isa<emitc::ArrayType>(value))
1009  return emitError() << "!emitc.lvalue cannot wrap !emitc.array type";
1010 
1011  return success();
1012 }
1013 
1014 //===----------------------------------------------------------------------===//
1015 // OpaqueType
1016 //===----------------------------------------------------------------------===//
1017 
/// Verifies `!emitc.opaque<"...">`. The spelled type must be non-empty and must
/// not end in `*`; pointer types are expressed with `!emitc.ptr` instead.
/// NOTE(review): source line 1019 (presumably the `emitError` callback
/// parameter) is missing from this listing.
1018 LogicalResult mlir::emitc::OpaqueType::verify(
1020  llvm::StringRef value) {
1021  if (value.empty()) {
1022  return emitError() << "expected non empty string in !emitc.opaque type";
1023  }
1024  if (value.back() == '*') {
1025  return emitError() << "pointer not allowed as outer type with "
1026  "!emitc.opaque, use !emitc.ptr instead";
1027  }
1028  return success();
1029 }
1030 
1031 //===----------------------------------------------------------------------===//
1032 // PointerType
1033 //===----------------------------------------------------------------------===//
1034 
/// Verifies `!emitc.ptr<T>`: T (named `value` here) must not be an
/// `!emitc.lvalue` type.
/// NOTE(review): source line 1036 (the parameter list) is missing from this
/// listing.
1035 LogicalResult mlir::emitc::PointerType::verify(
1037  if (llvm::isa<emitc::LValueType>(value))
1038  return emitError() << "pointers to lvalues are not allowed";
1039 
1040  return success();
1041 }
1042 
1043 //===----------------------------------------------------------------------===//
1044 // GlobalOp
1045 //===----------------------------------------------------------------------===//
1047  TypeAttr type,
1048  Attribute initialValue) {
// Custom printer for the type and optional initializer of `emitc.global`.
// It prints the type, then ` = <value>` (the value without its type) when an
// initial value is present.
// NOTE(review): the first signature line (source line 1046) is missing from
// this listing.
1049  p << type;
1050  if (initialValue) {
1051  p << " = ";
1052  p.printAttributeWithoutType(initialValue);
1053  }
1054 }
1055 
// Returns the type that a global's initializer attribute is parsed and checked
// against. An `!emitc.array` maps to a ranked tensor type with the same shape
// and element type; any other type is returned unchanged.
// NOTE(review): the signature line (source line 1056) is missing from this
// listing.
1057  if (auto array = llvm::dyn_cast<ArrayType>(type))
1058  return RankedTensorType::get(array.getShape(), array.getElementType());
1059  return type;
1060 }
1061 
/// Custom parser for `type [= initial-value]` of `emitc.global`.
/// The type is stored in `typeAttr`. When `=` follows, the initializer is
/// parsed against getInitializerTypeForGlobal(type), so array globals take
/// tensor-typed elements attributes. The initializer must be an elements,
/// integer, float or opaque attribute, otherwise an error is emitted at the
/// op name location.
/// NOTE(review): source line 1063 (the rest of the signature) is missing from
/// this listing.
1062 static ParseResult
1064  Attribute &initialValue) {
1065  Type type;
1066  if (parser.parseType(type))
1067  return failure();
1068 
1069  typeAttr = TypeAttr::get(type);
1070 
// No `=`: the global has no initializer.
1071  if (parser.parseOptionalEqual())
1072  return success();
1073 
1074  if (parser.parseAttribute(initialValue, getInitializerTypeForGlobal(type)))
1075  return failure();
1076 
1077  if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1078  initialValue))
1079  return parser.emitError(parser.getNameLoc())
1080  << "initial value should be a integer, float, elements or opaque "
1081  "attribute";
1082  return success();
1083 }
1084 
1085 LogicalResult GlobalOp::verify() {
1086  if (!isSupportedEmitCType(getType())) {
1087  return emitOpError("expected valid emitc type");
1088  }
1089  if (getInitialValue().has_value()) {
1090  Attribute initValue = getInitialValue().value();
1091  // Check that the type of the initial value is compatible with the type of
1092  // the global variable.
1093  if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1094  auto arrayType = llvm::dyn_cast<ArrayType>(getType());
1095  if (!arrayType)
1096  return emitOpError("expected array type, but got ") << getType();
1097 
1098  Type initType = elementsAttr.getType();
1099  Type tensorType = getInitializerTypeForGlobal(getType());
1100  if (initType != tensorType) {
1101  return emitOpError("initial value expected to be of type ")
1102  << getType() << ", but was of type " << initType;
1103  }
1104  } else if (auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
1105  if (intAttr.getType() != getType()) {
1106  return emitOpError("initial value expected to be of type ")
1107  << getType() << ", but was of type " << intAttr.getType();
1108  }
1109  } else if (auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
1110  if (floatAttr.getType() != getType()) {
1111  return emitOpError("initial value expected to be of type ")
1112  << getType() << ", but was of type " << floatAttr.getType();
1113  }
1114  } else if (!isa<emitc::OpaqueAttr>(initValue)) {
1115  return emitOpError("initial value should be a integer, float, elements "
1116  "or opaque attribute, but got ")
1117  << initValue;
1118  }
1119  }
1120  if (getStaticSpecifier() && getExternSpecifier()) {
1121  return emitOpError("cannot have both static and extern specifiers");
1122  }
1123  return success();
1124 }
1125 
1126 //===----------------------------------------------------------------------===//
1127 // GetGlobalOp
1128 //===----------------------------------------------------------------------===//
1129 
1130 LogicalResult
1131 GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1132  // Verify that the type matches the type of the global variable.
1133  auto global =
1134  symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1135  if (!global)
1136  return emitOpError("'")
1137  << getName() << "' does not reference a valid emitc.global";
1138 
1139  Type resultType = getResult().getType();
1140  Type globalType = global.getType();
1141 
1142  // global has array type
1143  if (llvm::isa<ArrayType>(globalType)) {
1144  if (globalType != resultType)
1145  return emitOpError("on array type expects result type ")
1146  << resultType << " to match type " << globalType
1147  << " of the global @" << getName();
1148  return success();
1149  }
1150 
1151  // global has non-array type
1152  auto lvalueType = dyn_cast<LValueType>(resultType);
1153  if (!lvalueType || lvalueType.getValueType() != globalType)
1154  return emitOpError("on non-array type expects result inner type ")
1155  << lvalueType.getValueType() << " to match type " << globalType
1156  << " of the global @" << getName();
1157  return success();
1158 }
1159 
1160 //===----------------------------------------------------------------------===//
1161 // SwitchOp
1162 //===----------------------------------------------------------------------===//
1163 
1164 /// Parse the case regions and values.
/// Repeatedly parses `case <integer> <region>` (regions take no entry-block
/// arguments) while the `case` keyword is present. The collected values are
/// stored in `cases` as a DenseI64ArrayAttr. Fails if an integer or a region
/// fails to parse.
/// NOTE(review): source line 1166 (the rest of the signature) is missing from
/// this listing.
1165 static ParseResult
1167  SmallVectorImpl<std::unique_ptr<Region>> &caseRegions) {
1168  SmallVector<int64_t> caseValues;
1169  while (succeeded(parser.parseOptionalKeyword("case"))) {
1170  int64_t value;
// The region is appended before its contents are parsed, so the parser can
// fill it in place.
1171  Region &region = *caseRegions.emplace_back(std::make_unique<Region>());
1172  if (parser.parseInteger(value) ||
1173  parser.parseRegion(region, /*arguments=*/{}))
1174  return failure();
1175  caseValues.push_back(value);
1176  }
1177  cases = parser.getBuilder().getDenseI64ArrayAttr(caseValues);
1178  return success();
1179 }
1180 
1181 /// Print the case regions and values.
/// Each case is printed on a new line as `case <value> ` followed by its
/// region, without entry-block arguments. Values and regions are paired with
/// llvm::zip, so printing stops at the shorter of the two sequences.
/// NOTE(review): the first signature line (source line 1182) is missing from
/// this listing.
1183  DenseI64ArrayAttr cases, RegionRange caseRegions) {
1184  for (auto [value, region] : llvm::zip(cases.asArrayRef(), caseRegions)) {
1185  p.printNewline();
1186  p << "case " << value << ' ';
1187  p.printRegion(*region, /*printEntryBlockArgs=*/false);
1188  }
1189 }
1190 
1191 static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region,
1192  const Twine &name) {
1193  auto yield = dyn_cast<emitc::YieldOp>(region.front().back());
1194  if (!yield)
1195  return op.emitOpError("expected region to end with emitc.yield, but got ")
1196  << region.front().back().getName();
1197 
1198  if (yield.getNumOperands() != 0) {
1199  return (op.emitOpError("expected each region to return ")
1200  << "0 values, but " << name << " returns "
1201  << yield.getNumOperands())
1202  .attachNote(yield.getLoc())
1203  << "see yield operation here";
1204  }
1205 
1206  return success();
1207 }
1208 
1209 LogicalResult emitc::SwitchOp::verify() {
1210  if (!isIntegerIndexOrOpaqueType(getArg().getType()))
1211  return emitOpError("unsupported type ") << getArg().getType();
1212 
1213  if (getCases().size() != getCaseRegions().size()) {
1214  return emitOpError("has ")
1215  << getCaseRegions().size() << " case regions but "
1216  << getCases().size() << " case values";
1217  }
1218 
1219  DenseSet<int64_t> valueSet;
1220  for (int64_t value : getCases())
1221  if (!valueSet.insert(value).second)
1222  return emitOpError("has duplicate case value: ") << value;
1223 
1224  if (failed(verifyRegion(*this, getDefaultRegion(), "default region")))
1225  return failure();
1226 
1227  for (auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
1228  if (failed(verifyRegion(*this, caseRegion, "case region #" + Twine(idx))))
1229  return failure();
1230 
1231  return success();
1232 }
1233 
1234 unsigned emitc::SwitchOp::getNumCases() { return getCases().size(); }
1235 
1236 Block &emitc::SwitchOp::getDefaultBlock() { return getDefaultRegion().front(); }
1237 
1238 Block &emitc::SwitchOp::getCaseBlock(unsigned idx) {
1239  assert(idx < getNumCases() && "case index out-of-bounds");
1240  return getCaseRegions()[idx].front();
1241 }
1242 
1243 void SwitchOp::getSuccessorRegions(
1245  llvm::copy(getRegions(), std::back_inserter(successors));
1246 }
1247 
1248 void SwitchOp::getEntrySuccessorRegions(
1249  ArrayRef<Attribute> operands,
1250  SmallVectorImpl<RegionSuccessor> &successors) {
1251  FoldAdaptor adaptor(operands, *this);
1252 
1253  // If a constant was not provided, all regions are possible successors.
1254  auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
1255  if (!arg) {
1256  llvm::copy(getRegions(), std::back_inserter(successors));
1257  return;
1258  }
1259 
1260  // Otherwise, try to find a case with a matching value. If not, the
1261  // default region is the only successor.
1262  for (auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
1263  if (caseValue == arg.getInt()) {
1264  successors.emplace_back(&caseRegion);
1265  return;
1266  }
1267  }
1268  successors.emplace_back(&getDefaultRegion());
1269 }
1270 
1271 void SwitchOp::getRegionInvocationBounds(
1273  auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
1274  if (!operandValue) {
1275  // All regions are invoked at most once.
1276  bounds.append(getNumRegions(), InvocationBounds(/*lb=*/0, /*ub=*/1));
1277  return;
1278  }
1279 
1280  unsigned liveIndex = getNumRegions() - 1;
1281  const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
1282 
1283  liveIndex = iteratorToInt != getCases().end()
1284  ? std::distance(getCases().begin(), iteratorToInt)
1285  : liveIndex;
1286 
1287  for (unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
1288  ++regIndex)
1289  bounds.emplace_back(/*lb=*/0, /*ub=*/regIndex == liveIndex);
1290 }
1291 
1292 //===----------------------------------------------------------------------===//
1293 // TableGen'd op method definitions
1294 //===----------------------------------------------------------------------===//
1295 
1296 #define GET_OP_CLASSES
1297 #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
Definition: AffineOps.cpp:720
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
Definition: AffineOps.cpp:712
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static LogicalResult verifyInitializationAttribute(Operation *op, Attribute value)
Check that the type of the initial value is compatible with the operations result type.
Definition: EmitC.cpp:141
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition: EmitC.cpp:1191
static ParseResult parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
Definition: EmitC.cpp:1063
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
Definition: EmitC.cpp:1166
static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
Definition: EmitC.cpp:1046
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
Definition: EmitC.cpp:1182
static Type getInitializerTypeForGlobal(Type type)
Definition: EmitC.cpp:1056
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
Definition: FoldUtils.cpp:50
static MLIRContext * getContext(OpFoldResult val)
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
Definition: SPIRVOps.cpp:215
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
Definition: AsmPrinter.cpp:73
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
Block represents an ordered list of Operations.
Definition: Block.h:31
BlockArgument getArgument(unsigned i)
Definition: Block.h:127
Operation & back()
Definition: Block.h:150
Operation * getTerminator()
Get the terminator operation of this block.
Definition: Block.cpp:243
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition: Block.cpp:152
bool mightHaveTerminator()
Check whether this block might have a terminator.
Definition: Block.cpp:249
Operation & front()
Definition: Block.h:151
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
Definition: Block.h:207
This class is a general helper class for creating context-global objects like types,...
Definition: Builders.h:50
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition: Builders.cpp:120
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:293
IndexType getIndexType()
Definition: Builders.cpp:95
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
Definition: Diagnostics.h:313
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:356
This class helps build Operations.
Definition: Builders.h:215
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition: Builders.h:439
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition: Builders.cpp:461
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:488
This class represents a single result from folding an operation.
Definition: OpDefinition.h:268
type_range getType() const
Definition: ValueRange.cpp:30
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:745
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
This class implements Optional functionality for ParseResult.
Definition: OpDefinition.h:39
bool has_value() const
Returns true if we contain a valid ParseResult value.
Definition: OpDefinition.h:49
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition: Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition: Region.h:26
bool empty()
Definition: Region.h:60
Block & front()
Definition: Region.h:65
This class represents a collection of SymbolTables.
Definition: SymbolTable.h:283
Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition: SymbolTable.h:76
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
bool isIndex() const
Definition: Types.cpp:64
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
bool hasOneUse() const
Returns true if this value has exactly one use.
Definition: Value.h:215
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
A named class for passing around the variadic flag.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for builders of ops carrying a region.
Definition: EmitC.cpp:57
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
Definition: EmitC.cpp:116
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
Definition: EmitC.cpp:61
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t/ssize_t type.
Definition: EmitC.cpp:134
bool isIntegerIndexOrOpaqueType(Type type)
Determines whether type is integer like, i.e.
Definition: EmitC.cpp:111
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
Definition: EmitC.cpp:95
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition: Utils.cpp:18
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Definition: Query.cpp:20
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:498
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
Definition: Verifier.cpp:426
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
NamedAttrList attributes
Region * addRegion()
Create a region that should be attached to the operation.