18#include "llvm/ADT/STLExtras.h"
19#include "llvm/ADT/SmallVector.h"
20#include "llvm/ADT/TypeSwitch.h"
21#include "llvm/Support/Casting.h"
26#include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc"
32void EmitCDialect::initialize() {
35#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
38#define GET_TYPEDEF_LIST
39#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
42#define GET_ATTRDEF_LIST
43#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
52 return emitc::ConstantOp::create(builder, loc, type, value);
58 emitc::YieldOp::create(builder, loc);
62 if (llvm::isa<emitc::OpaqueType>(type))
64 if (
auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
66 if (
auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
67 auto elemType = arrayType.getElementType();
68 return !llvm::isa<emitc::ArrayType>(elemType) &&
73 if (llvm::isa<IntegerType>(type))
75 if (llvm::isa<FloatType>(type))
77 if (
auto tensorType = llvm::dyn_cast<TensorType>(type)) {
78 if (!tensorType.hasStaticShape()) {
81 auto elemType = tensorType.getElementType();
82 if (llvm::isa<emitc::ArrayType>(elemType)) {
87 if (
auto tupleType = llvm::dyn_cast<TupleType>(type)) {
88 return llvm::all_of(tupleType.getTypes(), [](
Type type) {
89 return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
96 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
97 switch (intType.getWidth()) {
112 return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
117 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
118 switch (floatType.getWidth()) {
120 return llvm::isa<Float16Type, BFloat16Type>(type);
132 return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
139 isa<emitc::PointerType>(type);
146 assert(op->
getNumResults() == 1 &&
"operation must have 1 result");
148 if (llvm::isa<emitc::OpaqueAttr>(value))
151 if (llvm::isa<StringAttr>(value))
153 <<
"string attributes are not supported, use #emitc.opaque instead";
156 if (
auto lType = dyn_cast<LValueType>(resultType))
157 resultType = lType.getValueType();
158 Type attrType = cast<TypedAttr>(value).getType();
163 if (resultType != attrType)
165 <<
"requires attribute to either be an #emitc.opaque attribute or "
167 << attrType <<
") to match the op's result type (" << resultType
178template <
class ArgType>
180 StringRef toParse, ArgType fmtArgs,
185 if (fmtArgs.empty()) {
186 items.push_back(toParse);
190 while (!toParse.empty()) {
191 size_t idx = toParse.find(
'{');
192 if (idx == StringRef::npos) {
194 items.push_back(toParse);
199 items.push_back(toParse.take_front(idx));
200 toParse = toParse.drop_front(idx);
203 if (toParse.size() < 2) {
204 return emitError() <<
"expected '}' after unescaped '{' at end of string";
207 char nextChar = toParse[1];
208 if (nextChar ==
'{') {
210 items.push_back(toParse.take_front(1));
211 toParse = toParse.drop_front(2);
214 if (nextChar ==
'}') {
216 toParse = toParse.drop_front(2);
221 return emitError() <<
"expected '}' after unescaped '{'";
232LogicalResult AddressOfOp::verify() {
233 emitc::LValueType referenceType = getReference().getType();
234 emitc::PointerType resultType = getResult().getType();
236 if (referenceType.getValueType() != resultType.getPointee())
237 return emitOpError(
"requires result to be a pointer to the type "
238 "referenced by operand");
247LogicalResult AddOp::verify() {
248 Type lhsType = getLhs().getType();
249 Type rhsType = getRhs().getType();
251 if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
252 return emitOpError(
"requires that at most one operand is a pointer");
254 if ((isa<emitc::PointerType>(lhsType) &&
255 !isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
256 (isa<emitc::PointerType>(rhsType) &&
257 !isa<IntegerType, emitc::OpaqueType>(lhsType)))
258 return emitOpError(
"requires that one operand is an integer or of opaque "
259 "type if the other is a pointer");
268LogicalResult ApplyOp::verify() {
269 StringRef applicableOperatorStr = getApplicableOperator();
272 if (applicableOperatorStr.empty())
273 return emitOpError(
"applicable operator must not be empty");
276 if (applicableOperatorStr !=
"&" && applicableOperatorStr !=
"*")
277 return emitOpError(
"applicable operator is illegal");
279 Type operandType = getOperand().getType();
280 Type resultType = getResult().getType();
281 if (applicableOperatorStr ==
"&") {
282 if (!llvm::isa<emitc::LValueType>(operandType))
283 return emitOpError(
"operand type must be an lvalue when applying `&`");
284 if (!llvm::isa<emitc::PointerType>(resultType))
285 return emitOpError(
"result type must be a pointer when applying `&`");
287 if (!llvm::isa<emitc::PointerType>(operandType))
288 return emitOpError(
"operand type must be a pointer when applying `*`");
300LogicalResult emitc::AssignOp::verify() {
303 if (!variable.getDefiningOp())
304 return emitOpError() <<
"cannot assign to block argument";
306 Type valueType = getValue().getType();
307 Type variableType = variable.getType().getValueType();
308 if (variableType != valueType)
309 return emitOpError() <<
"requires value's type (" << valueType
310 <<
") to match variable's type (" << variableType
311 <<
")\n variable: " << variable
312 <<
"\n value: " << getValue() <<
"\n";
321 Type input = inputs.front(), output = outputs.front();
323 if (
auto arrayType = dyn_cast<emitc::ArrayType>(input)) {
324 if (
auto pointerType = dyn_cast<emitc::PointerType>(output)) {
325 return (arrayType.getElementType() == pointerType.getPointee()) &&
326 arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1;
342LogicalResult emitc::CallOpaqueOp::verify() {
344 if (getCallee().empty())
347 if (std::optional<ArrayAttr> argsAttr = getArgs()) {
349 auto intAttr = llvm::dyn_cast<IntegerAttr>(arg);
350 if (intAttr && llvm::isa<IndexType>(intAttr.getType())) {
355 return emitOpError(
"index argument is out of range");
358 }
else if (llvm::isa<ArrayAttr>(
366 if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
367 for (
Attribute tArg : *templateArgsAttr) {
368 if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg))
369 return emitOpError(
"template argument has invalid type");
373 if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
374 return emitOpError() <<
"cannot return array type";
384LogicalResult emitc::ConstantOp::verify() {
388 if (
auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(value)) {
389 if (opaqueValue.getValue().empty())
395OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
401LogicalResult DereferenceOp::verify() {
402 emitc::PointerType pointerType = getPointer().getType();
405 return emitOpError(
"requires result to be an lvalue of the type "
406 "pointed to by operand");
420 result.addAttribute(ExpressionOp::getDoNotInlineAttrName(
result.name),
425 "expected function type");
426 auto fnType = llvm::dyn_cast<FunctionType>(type);
429 "expected function type");
433 if (fnType.getNumResults() != 1)
435 "expected single return type");
436 result.addTypes(fnType.getResults());
439 for (
auto [unresolvedOperand, operandType] :
440 llvm::zip(operands, fnType.getInputs())) {
442 argInfo.
ssaName = unresolvedOperand;
443 argInfo.
type = operandType;
444 argsInfo.push_back(argInfo);
462 auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
463 Value yieldedValue = yieldOp.getResult();
467LogicalResult ExpressionOp::verify() {
468 Type resultType = getResult().getType();
469 Region ®ion = getRegion();
474 return emitOpError(
"must yield a value at termination");
477 Value yieldResult = yield.getResult();
480 return emitOpError(
"must yield a value at termination");
485 return emitOpError(
"yielded value has no defining op");
488 return emitOpError(
"yielded value not defined within expression");
492 if (resultType != yieldType)
493 return emitOpError(
"requires yielded type to match return type");
496 auto expressionInterface = dyn_cast<emitc::CExpressionInterface>(op);
497 if (!expressionInterface)
498 return emitOpError(
"contains an unsupported operation");
499 if (op.getNumResults() != 1)
500 return emitOpError(
"requires exactly one result for each operation");
503 return emitOpError(
"contains an unused operation");
510 worklist.push_back(rootOp);
511 while (!worklist.empty()) {
514 if (visited.contains(op)) {
517 "requires exactly one use for operations with side effects");
521 if (
Operation *def = operand.getDefiningOp()) {
522 worklist.push_back(def);
544 ForOp::ensureTerminator(*bodyRegion, builder,
result.location);
571 regionArgs.push_back(inductionVariable);
580 regionArgs.front().type = type;
591 ForOp::ensureTerminator(*body, builder,
result.location);
601 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
606 p <<
" : " << t <<
' ';
613LogicalResult ForOp::verifyRegions() {
616 if (getBody()->getNumArguments() != 1)
617 return emitOpError(
"expected body to have a single block argument for the "
618 "induction variable");
622 "expected induction variable to be same type as bounds and step");
635 return emitOpError(
"requires a 'callee' symbol reference attribute");
639 <<
"' does not reference a valid function";
642 auto fnType = fn.getFunctionType();
643 if (fnType.getNumInputs() != getNumOperands())
644 return emitOpError(
"incorrect number of operands for callee");
646 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
647 if (getOperand(i).
getType() != fnType.getInput(i))
648 return emitOpError(
"operand type mismatch: expected operand type ")
649 << fnType.getInput(i) <<
", but provided "
650 << getOperand(i).getType() <<
" for operand number " << i;
652 if (fnType.getNumResults() != getNumResults())
653 return emitOpError(
"incorrect number of results for callee");
655 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
656 if (getResult(i).
getType() != fnType.getResult(i)) {
658 diag.attachNote() <<
" op result types: " << getResultTypes();
659 diag.attachNote() <<
"function result types: " << fnType.getResults();
666FunctionType CallOp::getCalleeType() {
667 return FunctionType::get(
getContext(), getOperandTypes(), getResultTypes());
677 auto fnAttr = getSymNameAttr();
679 return emitOpError(
"requires a 'sym_name' symbol reference attribute");
683 <<
"' does not reference a valid function";
697 state.
addAttribute(getFunctionTypeAttrName(state.
name), TypeAttr::get(type));
701 if (argAttrs.empty())
703 assert(type.getNumInputs() == argAttrs.size());
705 builder, state, argAttrs, {},
706 getArgAttrsAttrName(state.
name), getResAttrsAttrName(state.
name));
717 getFunctionTypeAttrName(
result.name), buildFuncType,
718 getArgAttrsAttrName(
result.name), getResAttrsAttrName(
result.name));
723 p, *
this,
false, getFunctionTypeAttrName(),
724 getArgAttrsAttrName(), getResAttrsAttrName());
727LogicalResult FuncOp::verify() {
728 if (llvm::any_of(getArgumentTypes(), llvm::IsaPred<LValueType>)) {
729 return emitOpError(
"cannot have lvalue type as argument");
732 if (getNumResults() > 1)
733 return emitOpError(
"requires zero or exactly one result, but has ")
736 if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
746LogicalResult ReturnOp::verify() {
747 auto function = cast<FuncOp>((*this)->getParentOp());
750 if (getNumOperands() != function.getNumResults())
752 << getNumOperands() <<
" operands, but enclosing function (@"
753 << function.getName() <<
") returns " << function.getNumResults();
755 if (function.getNumResults() == 1)
756 if (getOperand().
getType() != function.getResultTypes()[0])
757 return emitError() <<
"type of the return operand ("
758 << getOperand().getType()
759 <<
") doesn't match function result type ("
760 << function.getResultTypes()[0] <<
")"
761 <<
" in function @" << function.getName();
770 bool addThenBlock,
bool addElseBlock) {
771 assert((!addElseBlock || addThenBlock) &&
772 "must not create else block w/o then block");
786 bool withElseRegion) {
796 if (withElseRegion) {
804 assert(thenBuilder &&
"the builder callback for 'then' must be present");
811 thenBuilder(builder,
result.location);
817 elseBuilder(builder,
result.location);
823 result.regions.reserve(2);
852 bool printBlockTerminators =
false;
854 p <<
" " << getCondition();
858 printBlockTerminators);
861 Region &elseRegion = getElseRegion();
862 if (!elseRegion.
empty()) {
866 printBlockTerminators);
888 Region *elseRegion = &this->getElseRegion();
889 if (elseRegion->
empty())
902 FoldAdaptor adaptor(operands, *
this);
903 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
904 if (!boolAttr || boolAttr.getValue())
905 regions.emplace_back(&getThenRegion());
908 if (!boolAttr || !boolAttr.getValue()) {
909 if (!getElseRegion().empty())
910 regions.emplace_back(&getElseRegion());
916void IfOp::getRegionInvocationBounds(
919 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
922 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
923 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
926 invocationBounds.assign(2, {0, 1});
935 bool standardInclude = getIsStandardInclude();
940 p <<
"\"" << getInclude() <<
"\"";
956 <<
"expected trailing '>' for standard include";
959 result.addAttribute(
"is_standard_include",
970LogicalResult emitc::LiteralOp::verify() {
971 if (getValue().empty())
979LogicalResult SubOp::verify() {
980 Type lhsType = getLhs().getType();
981 Type rhsType = getRhs().getType();
982 Type resultType = getResult().getType();
984 if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
985 return emitOpError(
"rhs can only be a pointer if lhs is a pointer");
987 if (isa<emitc::PointerType>(lhsType) &&
988 !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
989 return emitOpError(
"requires that rhs is an integer, pointer or of opaque "
990 "type if lhs is a pointer");
992 if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
993 !isa<IntegerType, emitc::PtrDiffTType, emitc::OpaqueType>(resultType))
994 return emitOpError(
"requires that the result is an integer, ptrdiff_t or "
995 "of opaque type if lhs and rhs are pointers");
1003LogicalResult emitc::VariableOp::verify() {
1011LogicalResult emitc::YieldOp::verify() {
1016 return emitOpError() <<
"yields a value not returned by parent";
1019 return emitOpError() <<
"does not yield a value to be returned by parent";
1028LogicalResult emitc::SubscriptOp::verify() {
1030 if (
auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().
getType())) {
1032 if (
getIndices().size() != (
size_t)arrayType.getRank()) {
1033 return emitOpError() <<
"on array operand requires number of indices ("
1035 <<
") to match the rank of the array type ("
1036 << arrayType.getRank() <<
")";
1039 for (
unsigned i = 0, e =
getIndices().size(); i != e; ++i) {
1042 return emitOpError() <<
"on array operand requires index operand " << i
1043 <<
" to be integer-like, but got " << type;
1047 Type elementType = arrayType.getElementType();
1049 if (elementType != resultType) {
1050 return emitOpError() <<
"on array operand requires element type ("
1051 << elementType <<
") and result type (" << resultType
1058 if (
auto pointerType =
1059 llvm::dyn_cast<emitc::PointerType>(getValue().
getType())) {
1063 <<
"on pointer operand requires one index operand, but got "
1069 return emitOpError() <<
"on pointer operand requires index operand to be "
1070 "integer-like, but got "
1074 Type pointeeType = pointerType.getPointee();
1076 if (pointeeType != resultType) {
1077 return emitOpError() <<
"on pointer operand requires pointee type ("
1078 << pointeeType <<
") and result type (" << resultType
1093LogicalResult emitc::VerbatimOp::verify() {
1097 FailureOr<SmallVector<ReplacementItem>> fmt =
1102 size_t numPlaceholders = llvm::count_if(*fmt, [](
ReplacementItem &item) {
1103 return std::holds_alternative<Placeholder>(item);
1106 if (numPlaceholders != getFmtArgs().size()) {
1108 <<
"requires operands for each placeholder in the format string";
1113FailureOr<SmallVector<ReplacementItem>> emitc::VerbatimOp::parseFormatString() {
1115 return ::parseFormatString(getValue(), getFmtArgs());
1122#include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
1128#define GET_ATTRDEF_CLASSES
1129#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
1135#define GET_TYPEDEF_CLASSES
1136#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
1157 if (!isValidElementType(elementType))
1158 return parser.
emitError(typeLoc,
"invalid array element type '")
1159 << elementType <<
"'",
1163 return parser.
getChecked<ArrayType>(dimensions, elementType);
1166void emitc::ArrayType::print(
AsmPrinter &printer)
const {
1169 printer << dim <<
'x';
1175LogicalResult emitc::ArrayType::verify(
1179 return emitError() <<
"shape must not be empty";
1183 return emitError() <<
"dimensions must have non-negative size";
1187 return emitError() <<
"element type must not be none";
1189 if (!isValidElementType(elementType))
1190 return emitError() <<
"invalid array element type";
1197 Type elementType)
const {
1199 return emitc::ArrayType::get(
getShape(), elementType);
1200 return emitc::ArrayType::get(*
shape, elementType);
1207LogicalResult mlir::emitc::LValueType::verify(
1214 <<
"!emitc.lvalue must wrap supported emitc type, but got " << value;
1216 if (llvm::isa<emitc::ArrayType>(value))
1217 return emitError() <<
"!emitc.lvalue cannot wrap !emitc.array type";
1226LogicalResult mlir::emitc::OpaqueType::verify(
1228 llvm::StringRef value) {
1229 if (value.empty()) {
1230 return emitError() <<
"expected non empty string in !emitc.opaque type";
1232 if (value.back() ==
'*') {
1233 return emitError() <<
"pointer not allowed as outer type with "
1234 "!emitc.opaque, use !emitc.ptr instead";
1243LogicalResult mlir::emitc::PointerType::verify(
1245 if (llvm::isa<emitc::LValueType>(value))
1246 return emitError() <<
"pointers to lvalues are not allowed";
1265 if (
auto array = llvm::dyn_cast<ArrayType>(type))
1266 return RankedTensorType::get(array.getShape(), array.getElementType());
1277 typeAttr = TypeAttr::get(type);
1285 if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1288 <<
"initial value should be a integer, float, elements or opaque "
1293LogicalResult GlobalOp::verify() {
1297 if (getInitialValue().has_value()) {
1298 Attribute initValue = getInitialValue().value();
1301 if (
auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1302 auto arrayType = llvm::dyn_cast<ArrayType>(
getType());
1306 Type initType = elementsAttr.getType();
1308 if (initType != tensorType) {
1309 return emitOpError(
"initial value expected to be of type ")
1310 <<
getType() <<
", but was of type " << initType;
1312 }
else if (
auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
1313 if (intAttr.getType() !=
getType()) {
1314 return emitOpError(
"initial value expected to be of type ")
1315 <<
getType() <<
", but was of type " << intAttr.getType();
1317 }
else if (
auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
1318 if (floatAttr.getType() !=
getType()) {
1319 return emitOpError(
"initial value expected to be of type ")
1320 <<
getType() <<
", but was of type " << floatAttr.getType();
1322 }
else if (!isa<emitc::OpaqueAttr>(initValue)) {
1323 return emitOpError(
"initial value should be a integer, float, elements "
1324 "or opaque attribute, but got ")
1328 if (getStaticSpecifier() && getExternSpecifier()) {
1329 return emitOpError(
"cannot have both static and extern specifiers");
1345 << getName() <<
"' does not reference a valid emitc.global";
1347 Type resultType = getResult().getType();
1348 Type globalType = global.getType();
1351 if (llvm::isa<ArrayType>(globalType)) {
1352 if (globalType != resultType)
1353 return emitOpError(
"on array type expects result type ")
1354 << resultType <<
" to match type " << globalType
1355 <<
" of the global @" << getName();
1360 auto lvalueType = dyn_cast<LValueType>(resultType);
1362 return emitOpError(
"on non-array type expects result type to be an "
1363 "lvalue type for the global @")
1365 if (lvalueType.getValueType() != globalType)
1366 return emitOpError(
"on non-array type expects result inner type ")
1367 << lvalueType.getValueType() <<
" to match type " << globalType
1368 <<
" of the global @" << getName();
1383 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
1387 caseValues.push_back(value);
1396 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
1398 p <<
"case " << value <<
' ';
1404 const Twine &name) {
1405 auto yield = dyn_cast<emitc::YieldOp>(region.
front().
back());
1407 return op.emitOpError(
"expected region to end with emitc.yield, but got ")
1410 if (yield.getNumOperands() != 0) {
1411 return (op.emitOpError(
"expected each region to return ")
1412 <<
"0 values, but " << name <<
" returns "
1413 << yield.getNumOperands())
1414 .attachNote(yield.getLoc())
1415 <<
"see yield operation here";
1421LogicalResult emitc::SwitchOp::verify() {
1423 return emitOpError(
"unsupported type ") << getArg().getType();
1425 if (getCases().size() != getCaseRegions().size()) {
1427 << getCaseRegions().size() <<
" case regions but "
1428 << getCases().size() <<
" case values";
1432 for (
int64_t value : getCases())
1433 if (!valueSet.insert(value).second)
1434 return emitOpError(
"has duplicate case value: ") << value;
1439 for (
auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
1446unsigned emitc::SwitchOp::getNumCases() {
return getCases().size(); }
1448Block &emitc::SwitchOp::getDefaultBlock() {
return getDefaultRegion().
front(); }
1450Block &emitc::SwitchOp::getCaseBlock(
unsigned idx) {
1451 assert(idx < getNumCases() &&
"case index out-of-bounds");
1452 return getCaseRegions()[idx].front();
1455void SwitchOp::getSuccessorRegions(
1457 llvm::append_range(successors, getRegions());
1460void SwitchOp::getEntrySuccessorRegions(
1463 FoldAdaptor adaptor(operands, *
this);
1466 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
1468 llvm::append_range(successors, getRegions());
1474 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
1475 if (caseValue == arg.getInt()) {
1476 successors.emplace_back(&caseRegion);
1480 successors.emplace_back(&getDefaultRegion());
1483void SwitchOp::getRegionInvocationBounds(
1485 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
1486 if (!operandValue) {
1492 unsigned liveIndex = getNumRegions() - 1;
1493 const auto *iteratorToInt = llvm::find(getCases(), operandValue.getInt());
1495 liveIndex = iteratorToInt != getCases().end()
1496 ? std::distance(getCases().begin(), iteratorToInt)
1499 for (
unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
1501 bounds.emplace_back(0, regIndex == liveIndex);
1528 if (
auto array = llvm::dyn_cast<ArrayType>(type))
1529 return RankedTensorType::get(array.getShape(), array.getElementType());
1540 typeAttr = TypeAttr::get(type);
1548 if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1551 <<
"initial value should be a integer, float, elements or opaque "
1556LogicalResult FieldOp::verify() {
1561 if (!parentOp || !isa<emitc::ClassOp>(parentOp))
1562 return emitOpError(
"field must be nested within an emitc.class operation");
1564 StringAttr symName = getSymNameAttr();
1565 if (!symName || symName.getValue().empty())
1566 return emitOpError(
"field must have a non-empty symbol name");
1575LogicalResult GetFieldOp::verify() {
1576 auto parentClassOp = getOperation()->getParentOfType<emitc::ClassOp>();
1577 if (!parentClassOp.getOperation())
1578 return emitOpError(
" must be nested within an emitc.class operation");
1589 << fieldNameAttr <<
"' not found in the class";
1591 Type getFieldResultType = getResult().getType();
1592 Type fieldType = fieldOp.getType();
1594 if (fieldType != getFieldResultType)
1596 << getFieldResultType <<
" does not match field '" << fieldNameAttr
1597 <<
"' type " << fieldType;
1614LogicalResult emitc::DoOp::verify() {
1615 Block &condBlock = getConditionRegion().
front();
1619 "condition region must contain exactly two operations: "
1620 "'emitc.expression' followed by 'emitc.yield', but found ")
1624 auto exprOp = dyn_cast<emitc::ExpressionOp>(first);
1626 return emitOpError(
"expected first op in condition region to be "
1627 "'emitc.expression', but got ")
1630 if (!exprOp.getResult().getType().isInteger(1))
1631 return emitOpError(
"emitc.expression in condition region must return "
1632 "'i1', but returns ")
1633 << exprOp.getResult().getType();
1636 auto condYield = dyn_cast<emitc::YieldOp>(last);
1638 return emitOpError(
"expected last op in condition region to be "
1639 "'emitc.yield', but got ")
1642 if (condYield.getNumOperands() != 1)
1643 return emitOpError(
"expected condition region to return 1 value, but "
1645 << condYield.getNumOperands() <<
" values";
1647 if (condYield.getOperand(0) != exprOp.getResult())
1648 return emitError(
"'emitc.yield' must return result of "
1649 "'emitc.expression' from this condition region");
1653 return emitOpError(
"body region must not contain terminator");
1666 if (bodyRegion->
empty())
1676#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.cpp.inc"
1678#define GET_OP_CLASSES
1679#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static bool hasSideEffects(Operation *op)
static LogicalResult verifyInitializationAttribute(Operation *op, Attribute value)
Check that the type of the initial value is compatible with the operations result type.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static ParseResult parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static Type getInitializerTypeForField(Type type)
static ParseResult parseEmitCFieldOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
FailureOr< SmallVector< ReplacementItem > > parseFormatString(StringRef toParse, ArgType fmtArgs, llvm::function_ref< mlir::InFlightDiagnostic()> emitError={})
Parse a format string and return a list of its parts.
static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op, TypeAttr type, Attribute initialValue)
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static Type getInitializerTypeForGlobal(Type type)
static Type getElementType(Type type)
Determine the element type of type.
static std::string diag(const llvm::Value &value)
static Type getValueType(Attribute attr)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
A symbol reference with a reference path containing a single element.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse)=0
Renumber the arguments for the specified region to the same names as the SSA values in namesToUse.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
This class represents a single result from folding an operation.
type_range getType() const
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
ArrayRef< T > asArrayRef() const
A named class for passing around the variadic flag.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for builders of ops carrying a region.
std::variant< StringRef, Placeholder > ReplacementItem
bool isFundamentalType(mlir::Type type)
Determines whether type is a valid fundamental C++ type in EmitC.
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
bool isPointerWideType(mlir::Type type)
Determines whether type is a emitc.size_t/ssize_t type.
bool isIntegerIndexOrOpaqueType(Type type)
Determines whether type is integer like, i.e.
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
llvm::function_ref< Fn > function_ref
UnresolvedOperand ssaName
This is the representation of an operand reference.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.