18 #include "llvm/ADT/STLExtras.h" 
   19 #include "llvm/ADT/SmallVector.h" 
   20 #include "llvm/ADT/TypeSwitch.h" 
   21 #include "llvm/Support/Casting.h" 
   26 #include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc" 
   32 void EmitCDialect::initialize() {
 
   35 #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc" 
   38 #define GET_TYPEDEF_LIST 
   39 #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc" 
   42 #define GET_ATTRDEF_LIST 
   43 #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc" 
   52   return emitc::ConstantOp::create(builder, loc, type, value);
 
   58   emitc::YieldOp::create(builder, loc);
 
   62   if (llvm::isa<emitc::OpaqueType>(type))
 
   64   if (
auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
 
   66   if (
auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
 
   67     auto elemType = arrayType.getElementType();
 
   68     return !llvm::isa<emitc::ArrayType>(elemType) &&
 
   73   if (llvm::isa<IntegerType>(type))
 
   75   if (llvm::isa<FloatType>(type))
 
   77   if (
auto tensorType = llvm::dyn_cast<TensorType>(type)) {
 
   78     if (!tensorType.hasStaticShape()) {
 
   81     auto elemType = tensorType.getElementType();
 
   82     if (llvm::isa<emitc::ArrayType>(elemType)) {
 
   87   if (
auto tupleType = llvm::dyn_cast<TupleType>(type)) {
 
   88     return llvm::all_of(tupleType.getTypes(), [](
Type type) {
 
   89       return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
 
   96   if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
 
   97     switch (intType.getWidth()) {
 
  112   return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
 
  117   if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
 
  118     switch (floatType.getWidth()) {
 
  120       return llvm::isa<Float16Type, BFloat16Type>(type);
 
  132   return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
 
  139          isa<emitc::PointerType>(type);
 
  146   assert(op->
getNumResults() == 1 && 
"operation must have 1 result");
 
  148   if (llvm::isa<emitc::OpaqueAttr>(value))
 
  151   if (llvm::isa<StringAttr>(value))
 
  153            << 
"string attributes are not supported, use #emitc.opaque instead";
 
  156   if (
auto lType = dyn_cast<LValueType>(resultType))
 
  157     resultType = lType.getValueType();
 
  158   Type attrType = cast<TypedAttr>(value).getType();
 
  163   if (resultType != attrType)
 
  165            << 
"requires attribute to either be an #emitc.opaque attribute or " 
  167            << attrType << 
") to match the op's result type (" << resultType
 
  178 template <
class ArgType>
 
  180     StringRef toParse, ArgType fmtArgs,
 
  185   if (fmtArgs.empty()) {
 
  186     items.push_back(toParse);
 
  190   while (!toParse.empty()) {
 
  191     size_t idx = toParse.find(
'{');
 
  192     if (idx == StringRef::npos) {
 
  194       items.push_back(toParse);
 
  199       items.push_back(toParse.take_front(idx));
 
  200       toParse = toParse.drop_front(idx);
 
  203     if (toParse.size() < 2) {
 
  204       return emitError() << 
"expected '}' after unescaped '{' at end of string";
 
  207     char nextChar = toParse[1];
 
  208     if (nextChar == 
'{') {
 
  210       items.push_back(toParse.take_front(1));
 
  211       toParse = toParse.drop_front(2);
 
  214     if (nextChar == 
'}') {
 
  216       toParse = toParse.drop_front(2);
 
  221       return emitError() << 
"expected '}' after unescaped '{'";
 
  233   Type lhsType = getLhs().getType();
 
  234   Type rhsType = getRhs().getType();
 
  236   if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
 
  237     return emitOpError(
"requires that at most one operand is a pointer");
 
  239   if ((isa<emitc::PointerType>(lhsType) &&
 
  240        !isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
 
  241       (isa<emitc::PointerType>(rhsType) &&
 
  242        !isa<IntegerType, emitc::OpaqueType>(lhsType)))
 
  243     return emitOpError(
"requires that one operand is an integer or of opaque " 
  244                        "type if the other is a pointer");
 
  254   StringRef applicableOperatorStr = getApplicableOperator();
 
  257   if (applicableOperatorStr.empty())
 
  258     return emitOpError(
"applicable operator must not be empty");
 
  261   if (applicableOperatorStr != 
"&" && applicableOperatorStr != 
"*")
 
  262     return emitOpError(
"applicable operator is illegal");
 
  264   Type operandType = getOperand().getType();
 
  265   Type resultType = getResult().getType();
 
  266   if (applicableOperatorStr == 
"&") {
 
  267     if (!llvm::isa<emitc::LValueType>(operandType))
 
  268       return emitOpError(
"operand type must be an lvalue when applying `&`");
 
  269     if (!llvm::isa<emitc::PointerType>(resultType))
 
  270       return emitOpError(
"result type must be a pointer when applying `&`");
 
  272     if (!llvm::isa<emitc::PointerType>(operandType))
 
  273       return emitOpError(
"operand type must be a pointer when applying `*`");
 
  288   if (!variable.getDefiningOp())
 
  289     return emitOpError() << 
"cannot assign to block argument";
 
  291   Type valueType = getValue().getType();
 
  292   Type variableType = variable.getType().getValueType();
 
  293   if (variableType != valueType)
 
  294     return emitOpError() << 
"requires value's type (" << valueType
 
  295                          << 
") to match variable's type (" << variableType
 
  296                          << 
")\n  variable: " << variable
 
  297                          << 
"\n  value: " << getValue() << 
"\n";
 
  306   Type input = inputs.front(), output = outputs.front();
 
  308   if (
auto arrayType = dyn_cast<emitc::ArrayType>(input)) {
 
  309     if (
auto pointerType = dyn_cast<emitc::PointerType>(output)) {
 
  310       return (arrayType.getElementType() == pointerType.getPointee()) &&
 
  311              arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1;
 
  329   if (getCallee().empty())
 
  330     return emitOpError(
"callee must not be empty");
 
  332   if (std::optional<ArrayAttr> argsAttr = getArgs()) {
 
  334       auto intAttr = llvm::dyn_cast<IntegerAttr>(arg);
 
  335       if (intAttr && llvm::isa<IndexType>(intAttr.getType())) {
 
  336         int64_t index = intAttr.getInt();
 
  339         if ((index < 0) || (index >= 
static_cast<int64_t
>(getNumOperands())))
 
  340           return emitOpError(
"index argument is out of range");
 
  343       } 
else if (llvm::isa<ArrayAttr>(
 
  346         return emitOpError(
"array argument has no type");
 
  351   if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
 
  352     for (
Attribute tArg : *templateArgsAttr) {
 
  353       if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg))
 
  354         return emitOpError(
"template argument has invalid type");
 
  358   if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
 
  359     return emitOpError() << 
"cannot return array type";
 
  373   if (
auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(value)) {
 
  374     if (opaqueValue.getValue().empty())
 
  375       return emitOpError() << 
"value must not be empty";
 
  380 OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) { 
return getValue(); }
 
  396                             "expected function type");
 
  397   auto fnType = llvm::dyn_cast<FunctionType>(type);
 
  400                             "expected function type");
 
  404   if (fnType.getNumResults() != 1)
 
  406                             "expected single return type");
 
  407   result.
addTypes(fnType.getResults());
 
  410   for (
auto [unresolvedOperand, operandType] :
 
  411        llvm::zip(operands, fnType.getInputs())) {
 
  413     argInfo.
ssaName = unresolvedOperand;
 
  414     argInfo.
type = operandType;
 
  415     argsInfo.push_back(argInfo);
 
  433   auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
 
  434   Value yieldedValue = yieldOp.getResult();
 
  439   Type resultType = getResult().getType();
 
  440   Region ®ion = getRegion();
 
  445     return emitOpError(
"must yield a value at termination");
 
  448   Value yieldResult = yield.getResult();
 
  451     return emitOpError(
"must yield a value at termination");
 
  456     return emitOpError(
"yielded value has no defining op");
 
  459     return emitOpError(
"yielded value not defined within expression");
 
  463   if (resultType != yieldType)
 
  464     return emitOpError(
"requires yielded type to match return type");
 
  467     auto expressionInterface = dyn_cast<emitc::CExpressionInterface>(op);
 
  468     if (!expressionInterface)
 
  469       return emitOpError(
"contains an unsupported operation");
 
  470     if (op.getNumResults() != 1)
 
  471       return emitOpError(
"requires exactly one result for each operation");
 
  472     Value result = op.getResult(0);
 
  474       return emitOpError(
"contains an unused operation");
 
  481   worklist.push_back(rootOp);
 
  482   while (!worklist.empty()) {
 
  485     if (visited.contains(op)) {
 
  488             "requires exactly one use for operations with side effects");
 
  492       if (
Operation *def = operand.getDefiningOp()) {
 
  493         worklist.push_back(def);
 
  505                   Value ub, 
Value step, BodyBuilderFn bodyBuilder) {
 
  515     ForOp::ensureTerminator(*bodyRegion, builder, result.
location);
 
  542   regionArgs.push_back(inductionVariable);
 
  551   regionArgs.front().type = type;
 
  562   ForOp::ensureTerminator(*body, builder, result.
location);
 
  572   p << 
" " << getInductionVar() << 
" = " << 
getLowerBound() << 
" to " 
  577     p << 
" : " << t << 
' ';
 
  584 LogicalResult ForOp::verifyRegions() {
 
  589         "expected induction variable to be same type as bounds and step");
 
  602     return emitOpError(
"requires a 'callee' symbol reference attribute");
 
  605     return emitOpError() << 
"'" << fnAttr.getValue()
 
  606                          << 
"' does not reference a valid function";
 
  609   auto fnType = fn.getFunctionType();
 
  610   if (fnType.getNumInputs() != getNumOperands())
 
  611     return emitOpError(
"incorrect number of operands for callee");
 
  613   for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
 
  614     if (getOperand(i).
getType() != fnType.getInput(i))
 
  615       return emitOpError(
"operand type mismatch: expected operand type ")
 
  616              << fnType.getInput(i) << 
", but provided " 
  617              << getOperand(i).getType() << 
" for operand number " << i;
 
  619   if (fnType.getNumResults() != getNumResults())
 
  620     return emitOpError(
"incorrect number of results for callee");
 
  622   for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
 
  623     if (getResult(i).
getType() != fnType.getResult(i)) {
 
  624       auto diag = emitOpError(
"result type mismatch at index ") << i;
 
  625       diag.attachNote() << 
"      op result types: " << getResultTypes();
 
  626       diag.attachNote() << 
"function result types: " << fnType.getResults();
 
  633 FunctionType CallOp::getCalleeType() {
 
  644   auto fnAttr = getSymNameAttr();
 
  646     return emitOpError(
"requires a 'sym_name' symbol reference attribute");
 
  649     return emitOpError() << 
"'" << fnAttr.getValue()
 
  650                          << 
"' does not reference a valid function";
 
  664   state.addAttribute(getFunctionTypeAttrName(state.name), 
TypeAttr::get(type));
 
  665   state.attributes.append(attrs.begin(), attrs.end());
 
  668   if (argAttrs.empty())
 
  670   assert(type.getNumInputs() == argAttrs.size());
 
  672       builder, state, argAttrs, {},
 
  673       getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
 
  683       parser, result, 
false,
 
  684       getFunctionTypeAttrName(result.
name), buildFuncType,
 
  685       getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
 
  690       p, *
this, 
false, getFunctionTypeAttrName(),
 
  691       getArgAttrsAttrName(), getResAttrsAttrName());
 
  695   if (llvm::any_of(getArgumentTypes(), llvm::IsaPred<LValueType>)) {
 
  696     return emitOpError(
"cannot have lvalue type as argument");
 
  699   if (getNumResults() > 1)
 
  700     return emitOpError(
"requires zero or exactly one result, but has ")
 
  703   if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
 
  704     return emitOpError(
"cannot return array type");
 
  714   auto function = cast<FuncOp>((*this)->getParentOp());
 
  717   if (getNumOperands() != 
function.getNumResults())
 
  718     return emitOpError(
"has ")
 
  719            << getNumOperands() << 
" operands, but enclosing function (@" 
  720            << 
function.getName() << 
") returns " << 
function.getNumResults();
 
  722   if (
function.getNumResults() == 1)
 
  723     if (getOperand().
getType() != 
function.getResultTypes()[0])
 
  724       return emitError() << 
"type of the return operand (" 
  725                          << getOperand().getType()
 
  726                          << 
") doesn't match function result type (" 
  727                          << 
function.getResultTypes()[0] << 
")" 
  728                          << 
" in function @" << 
function.getName();
 
  737                  bool addThenBlock, 
bool addElseBlock) {
 
  738   assert((!addElseBlock || addThenBlock) &&
 
  739          "must not create else block w/o then block");
 
  753                  bool withElseRegion) {
 
  763   if (withElseRegion) {
 
  771   assert(thenBuilder && 
"the builder callback for 'then' must be present");
 
  778   thenBuilder(builder, result.
location);
 
  784     elseBuilder(builder, result.
location);
 
  819   bool printBlockTerminators = 
false;
 
  821   p << 
" " << getCondition();
 
  825                 printBlockTerminators);
 
  828   Region &elseRegion = getElseRegion();
 
  829   if (!elseRegion.
empty()) {
 
  833                   printBlockTerminators);
 
  856   Region *elseRegion = &this->getElseRegion();
 
  857   if (elseRegion->
empty())
 
  866   FoldAdaptor adaptor(operands, *
this);
 
  867   auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
 
  868   if (!boolAttr || boolAttr.getValue())
 
  869     regions.emplace_back(&getThenRegion());
 
  872   if (!boolAttr || !boolAttr.getValue()) {
 
  873     if (!getElseRegion().empty())
 
  874       regions.emplace_back(&getElseRegion());
 
  876       regions.emplace_back(getOperation(), getOperation()->getResults());
 
  880 void IfOp::getRegionInvocationBounds(
 
  883   if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
 
  886     invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
 
  887     invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
 
  890     invocationBounds.assign(2, {0, 1});
 
  899   bool standardInclude = getIsStandardInclude();
 
  904   p << 
"\"" << getInclude() << 
"\"";
 
  920            << 
"expected trailing '>' for standard include";
 
  935   if (getValue().empty())
 
  936     return emitOpError() << 
"value must not be empty";
 
  944   Type lhsType = getLhs().getType();
 
  945   Type rhsType = getRhs().getType();
 
  946   Type resultType = getResult().getType();
 
  948   if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
 
  949     return emitOpError(
"rhs can only be a pointer if lhs is a pointer");
 
  951   if (isa<emitc::PointerType>(lhsType) &&
 
  952       !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
 
  953     return emitOpError(
"requires that rhs is an integer, pointer or of opaque " 
  954                        "type if lhs is a pointer");
 
  956   if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
 
  957       !isa<IntegerType, emitc::PtrDiffTType, emitc::OpaqueType>(resultType))
 
  958     return emitOpError(
"requires that the result is an integer, ptrdiff_t or " 
  959                        "of opaque type if lhs and rhs are pointers");
 
  976   Value result = getResult();
 
  979   if (!isa<DoOp>(containingOp) && result && containingOp->
getNumResults() != 1)
 
  980     return emitOpError() << 
"yields a value not returned by parent";
 
  982   if (!isa<DoOp>(containingOp) && !result && containingOp->
getNumResults() != 0)
 
  983     return emitOpError() << 
"does not yield a value to be returned by parent";
 
  994   if (
auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().
getType())) {
 
  996     if (
getIndices().size() != (
size_t)arrayType.getRank()) {
 
  997       return emitOpError() << 
"on array operand requires number of indices (" 
  999                            << 
") to match the rank of the array type (" 
 1000                            << arrayType.getRank() << 
")";
 
 1003     for (
unsigned i = 0, e = 
getIndices().size(); i != e; ++i) {
 
 1006         return emitOpError() << 
"on array operand requires index operand " << i
 
 1007                              << 
" to be integer-like, but got " << type;
 
 1011     Type elementType = arrayType.getElementType();
 
 1013     if (elementType != resultType) {
 
 1014       return emitOpError() << 
"on array operand requires element type (" 
 1015                            << elementType << 
") and result type (" << resultType
 
 1022   if (
auto pointerType =
 
 1023           llvm::dyn_cast<emitc::PointerType>(getValue().
getType())) {
 
 1026       return emitOpError()
 
 1027              << 
"on pointer operand requires one index operand, but got " 
 1033       return emitOpError() << 
"on pointer operand requires index operand to be " 
 1034                               "integer-like, but got " 
 1038     Type pointeeType = pointerType.getPointee();
 
 1040     if (pointeeType != resultType) {
 
 1041       return emitOpError() << 
"on pointer operand requires pointee type (" 
 1042                            << pointeeType << 
") and result type (" << resultType
 
 1059     return this->emitOpError();
 
 1061   FailureOr<SmallVector<ReplacementItem>> fmt =
 
 1066   size_t numPlaceholders = llvm::count_if(*fmt, [](
ReplacementItem &item) {
 
 1067     return std::holds_alternative<Placeholder>(item);
 
 1070   if (numPlaceholders != getFmtArgs().size()) {
 
 1071     return emitOpError()
 
 1072            << 
"requires operands for each placeholder in the format string";
 
 1086 #include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc" 
 1092 #define GET_ATTRDEF_CLASSES 
 1093 #include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc" 
 1099 #define GET_TYPEDEF_CLASSES 
 1100 #include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc" 
 1121   if (!isValidElementType(elementType))
 
 1122     return parser.
emitError(typeLoc, 
"invalid array element type '")
 
 1123                << elementType << 
"'",
 
 1127   return parser.
getChecked<ArrayType>(dimensions, elementType);
 
 1133     printer << dim << 
'x';
 
 1143     return emitError() << 
"shape must not be empty";
 
 1145   for (int64_t dim : shape) {
 
 1147       return emitError() << 
"dimensions must have non-negative size";
 
 1151     return emitError() << 
"element type must not be none";
 
 1153   if (!isValidElementType(elementType))
 
 1154     return emitError() << 
"invalid array element type";
 
 1161                             Type elementType)
 const {
 
 1178            << 
"!emitc.lvalue must wrap supported emitc type, but got " << value;
 
 1180   if (llvm::isa<emitc::ArrayType>(value))
 
 1181     return emitError() << 
"!emitc.lvalue cannot wrap !emitc.array type";
 
 1192     llvm::StringRef value) {
 
 1193   if (value.empty()) {
 
 1194     return emitError() << 
"expected non empty string in !emitc.opaque type";
 
 1196   if (value.back() == 
'*') {
 
 1197     return emitError() << 
"pointer not allowed as outer type with " 
 1198                           "!emitc.opaque, use !emitc.ptr instead";
 
 1209   if (llvm::isa<emitc::LValueType>(value))
 
 1210     return emitError() << 
"pointers to lvalues are not allowed";
 
 1229   if (
auto array = llvm::dyn_cast<ArrayType>(type))
 
 1249   if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
 
 1252            << 
"initial value should be a integer, float, elements or opaque " 
 1259     return emitOpError(
"expected valid emitc type");
 
 1261   if (getInitialValue().has_value()) {
 
 1262     Attribute initValue = getInitialValue().value();
 
 1265     if (
auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
 
 1266       auto arrayType = llvm::dyn_cast<ArrayType>(
getType());
 
 1268         return emitOpError(
"expected array type, but got ") << 
getType();
 
 1270       Type initType = elementsAttr.getType();
 
 1272       if (initType != tensorType) {
 
 1273         return emitOpError(
"initial value expected to be of type ")
 
 1274                << 
getType() << 
", but was of type " << initType;
 
 1276     } 
else if (
auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
 
 1277       if (intAttr.getType() != 
getType()) {
 
 1278         return emitOpError(
"initial value expected to be of type ")
 
 1279                << 
getType() << 
", but was of type " << intAttr.getType();
 
 1281     } 
else if (
auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
 
 1282       if (floatAttr.getType() != 
getType()) {
 
 1283         return emitOpError(
"initial value expected to be of type ")
 
 1284                << 
getType() << 
", but was of type " << floatAttr.getType();
 
 1286     } 
else if (!isa<emitc::OpaqueAttr>(initValue)) {
 
 1287       return emitOpError(
"initial value should be a integer, float, elements " 
 1288                          "or opaque attribute, but got ")
 
 1292   if (getStaticSpecifier() && getExternSpecifier()) {
 
 1293     return emitOpError(
"cannot have both static and extern specifiers");
 
 1308     return emitOpError(
"'")
 
 1309            << getName() << 
"' does not reference a valid emitc.global";
 
 1311   Type resultType = getResult().getType();
 
 1312   Type globalType = global.getType();
 
 1315   if (llvm::isa<ArrayType>(globalType)) {
 
 1316     if (globalType != resultType)
 
 1317       return emitOpError(
"on array type expects result type ")
 
 1318              << resultType << 
" to match type " << globalType
 
 1319              << 
" of the global @" << getName();
 
 1324   auto lvalueType = dyn_cast<LValueType>(resultType);
 
 1326     return emitOpError(
"on non-array type expects result type to be an " 
 1327                        "lvalue type for the global @")
 
 1329   if (lvalueType.getValueType() != globalType)
 
 1330     return emitOpError(
"on non-array type expects result inner type ")
 
 1331            << lvalueType.getValueType() << 
" to match type " << globalType
 
 1332            << 
" of the global @" << getName();
 
 1347     Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
 
 1351     caseValues.push_back(value);
 
 1360   for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
 
 1362     p << 
"case " << value << 
' ';
 
 1368                                   const Twine &name) {
 
 1369   auto yield = dyn_cast<emitc::YieldOp>(region.
front().
back());
 
 1371     return op.emitOpError(
"expected region to end with emitc.yield, but got ")
 
 1374   if (yield.getNumOperands() != 0) {
 
 1375     return (op.emitOpError(
"expected each region to return ")
 
 1376             << 
"0 values, but " << name << 
" returns " 
 1377             << yield.getNumOperands())
 
 1378                .attachNote(yield.getLoc())
 
 1379            << 
"see yield operation here";
 
 1387     return emitOpError(
"unsupported type ") << getArg().getType();
 
 1389   if (getCases().size() != getCaseRegions().size()) {
 
 1390     return emitOpError(
"has ")
 
 1391            << getCaseRegions().size() << 
" case regions but " 
 1392            << getCases().size() << 
" case values";
 
 1396   for (int64_t value : getCases())
 
 1397     if (!valueSet.insert(value).second)
 
 1398       return emitOpError(
"has duplicate case value: ") << value;
 
 1410 unsigned emitc::SwitchOp::getNumCases() { 
return getCases().size(); }
 
 1412 Block &emitc::SwitchOp::getDefaultBlock() { 
return getDefaultRegion().
front(); }
 
 1414 Block &emitc::SwitchOp::getCaseBlock(
unsigned idx) {
 
 1415   assert(idx < getNumCases() && 
"case index out-of-bounds");
 
 1416   return getCaseRegions()[idx].front();
 
 1419 void SwitchOp::getSuccessorRegions(
 
 1421   llvm::append_range(successors, getRegions());
 
 1424 void SwitchOp::getEntrySuccessorRegions(
 
 1427   FoldAdaptor adaptor(operands, *
this);
 
 1430   auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
 
 1432     llvm::append_range(successors, getRegions());
 
 1438   for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
 
 1439     if (caseValue == arg.getInt()) {
 
 1440       successors.emplace_back(&caseRegion);
 
 1444   successors.emplace_back(&getDefaultRegion());
 
 1447 void SwitchOp::getRegionInvocationBounds(
 
 1449   auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
 
 1450   if (!operandValue) {
 
 1456   unsigned liveIndex = getNumRegions() - 1;
 
 1457   const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
 
 1459   liveIndex = iteratorToInt != getCases().end()
 
 1460                   ? std::distance(getCases().begin(), iteratorToInt)
 
 1463   for (
unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
 
 1465     bounds.emplace_back(0, regIndex == liveIndex);
 
 1472   state.addRegion()->emplaceBlock();
 
 1473   state.attributes.push_back(
 
 1492   if (
auto array = llvm::dyn_cast<ArrayType>(type))
 
 1512   if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
 
 1515            << 
"initial value should be a integer, float, elements or opaque " 
 1522     return emitOpError(
"expected valid emitc type");
 
 1525   if (!parentOp || !isa<emitc::ClassOp>(parentOp))
 
 1526     return emitOpError(
"field must be nested within an emitc.class operation");
 
 1528   StringAttr symName = getSymNameAttr();
 
 1529   if (!symName || symName.getValue().empty())
 
 1530     return emitOpError(
"field must have a non-empty symbol name");
 
 1540   auto parentClassOp = getOperation()->getParentOfType<emitc::ClassOp>();
 
 1541   if (!parentClassOp.getOperation())
 
 1542     return emitOpError(
" must be nested within an emitc.class operation");
 
 1552     return emitOpError(
"field '")
 
 1553            << fieldNameAttr << 
"' not found in the class";
 
 1555   Type getFieldResultType = getResult().getType();
 
 1556   Type fieldType = fieldOp.getType();
 
 1558   if (fieldType != getFieldResultType)
 
 1559     return emitOpError(
"result type ")
 
 1560            << getFieldResultType << 
" does not match field '" << fieldNameAttr
 
 1561            << 
"' type " << fieldType;
 
 1579   Block &condBlock = getConditionRegion().
front();
 
 1583                "condition region must contain exactly two operations: " 
 1584                "'emitc.expression' followed by 'emitc.yield', but found ")
 
 1588   auto exprOp = dyn_cast<emitc::ExpressionOp>(first);
 
 1590     return emitOpError(
"expected first op in condition region to be " 
 1591                        "'emitc.expression', but got ")
 
 1594   if (!exprOp.getResult().getType().isInteger(1))
 
 1595     return emitOpError(
"emitc.expression in condition region must return " 
 1596                        "'i1', but returns ")
 
 1597            << exprOp.getResult().getType();
 
 1600   auto condYield = dyn_cast<emitc::YieldOp>(last);
 
 1602     return emitOpError(
"expected last op in condition region to be " 
 1603                        "'emitc.yield', but got ")
 
 1606   if (condYield.getNumOperands() != 1)
 
 1607     return emitOpError(
"expected condition region to return 1 value, but " 
 1609            << condYield.getNumOperands() << 
" values";
 
 1611   if (condYield.getOperand(0) != exprOp.getResult())
 
 1612     return emitError(
"'emitc.yield' must return result of " 
 1613                      "'emitc.expression' from this condition region");
 
 1617     return emitOpError(
"body region must not contain terminator");
 
 1630   if (bodyRegion->
empty())
 
 1640 #include "mlir/Dialect/EmitC/IR/EmitCInterfaces.cpp.inc" 
 1642 #define GET_OP_CLASSES 
 1643 #include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc" 
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static bool hasSideEffects(Operation *op)
static LogicalResult verifyInitializationAttribute(Operation *op, Attribute value)
Check that the type of the initial value is compatible with the operations result type.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static ParseResult parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region >> &caseRegions)
Parse the case regions and values.
static Type getInitializerTypeForField(Type type)
FailureOr< SmallVector< ReplacementItem > > parseFormatString(StringRef toParse, ArgType fmtArgs, llvm::function_ref< mlir::InFlightDiagnostic()> emitError={})
Parse a format string and return a list of its parts.
static ParseResult parseEmitCFieldOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op, TypeAttr type, Attribute initialValue)
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static Type getInitializerTypeForGlobal(Type type)
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static std::string diag(const llvm::Value &value)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
OpListType & getOperations()
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void shadowRegionArgs(Region &region, ValueRange namesToUse)=0
Renumber the arguments for the specified region to the same names as the SSA values in namesToUse.
virtual void printNewline()=0
Print a newline and indent the printer to the start of the current operation.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This class represents a single result from folding an operation.
type_range getType() const
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
bool use_empty() const
Returns true if this value has no uses.
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Base class for DenseArrayAttr that is instantiated and specialized for each supported element type be...
ArrayRef< T > asArrayRef() const
A named class for passing around the variadic flag.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
std::variant< StringRef, Placeholder > ReplacementItem
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for builders of ops carrying a region.
bool isFundamentalType(mlir::Type type)
Determines whether type is a valid fundamental C++ type in EmitC.
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
bool isPointerWideType(mlir::Type type)
Determines whether type is an emitc.size_t/ssize_t type.
bool isIntegerIndexOrOpaqueType(Type type)
Determines whether type is integer like, i.e.
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
UnresolvedOperand ssaName
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
Region * addRegion()
Create a region that should be attached to the operation.