19#include "llvm/ADT/STLExtras.h"
20#include "llvm/ADT/SmallVector.h"
21#include "llvm/ADT/TypeSwitch.h"
22#include "llvm/Support/Casting.h"
27#include "mlir/Dialect/EmitC/IR/EmitCDialect.cpp.inc"
33void EmitCDialect::initialize() {
36#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
39#define GET_TYPEDEF_LIST
40#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
43#define GET_ATTRDEF_LIST
44#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
53 return emitc::ConstantOp::create(builder, loc, type, value);
59 emitc::YieldOp::create(builder, loc);
63 if (llvm::isa<emitc::OpaqueType>(type))
65 if (
auto ptrType = llvm::dyn_cast<emitc::PointerType>(type))
67 if (
auto arrayType = llvm::dyn_cast<emitc::ArrayType>(type)) {
68 auto elemType = arrayType.getElementType();
69 return !llvm::isa<emitc::ArrayType>(elemType) &&
74 if (llvm::isa<IntegerType>(type))
76 if (llvm::isa<FloatType>(type))
78 if (
auto tensorType = llvm::dyn_cast<TensorType>(type)) {
79 if (!tensorType.hasStaticShape()) {
82 auto elemType = tensorType.getElementType();
83 if (llvm::isa<emitc::ArrayType>(elemType)) {
88 if (
auto tupleType = llvm::dyn_cast<TupleType>(type)) {
89 return llvm::all_of(tupleType.getTypes(), [](
Type type) {
90 return !llvm::isa<emitc::ArrayType>(type) && isSupportedEmitCType(type);
97 if (
auto intType = llvm::dyn_cast<IntegerType>(type)) {
98 switch (intType.getWidth()) {
113 return llvm::isa<IndexType, emitc::OpaqueType>(type) ||
118 if (
auto floatType = llvm::dyn_cast<FloatType>(type)) {
119 switch (floatType.getWidth()) {
121 return llvm::isa<Float16Type, BFloat16Type>(type);
133 return isa<emitc::SignedSizeTType, emitc::SizeTType, emitc::PtrDiffTType>(
140 isa<emitc::PointerType>(type);
147 assert(op->
getNumResults() == 1 &&
"operation must have 1 result");
149 if (llvm::isa<emitc::OpaqueAttr>(value))
152 if (llvm::isa<StringAttr>(value))
154 <<
"string attributes are not supported, use #emitc.opaque instead";
157 if (
auto lType = dyn_cast<LValueType>(resultType))
158 resultType = lType.getValueType();
159 Type attrType = cast<TypedAttr>(value).getType();
164 if (resultType != attrType)
166 <<
"requires attribute to either be an #emitc.opaque attribute or "
168 << attrType <<
") to match the op's result type (" << resultType
179template <
class ArgType>
181 StringRef toParse, ArgType fmtArgs,
186 if (fmtArgs.empty()) {
187 items.push_back(toParse);
191 while (!toParse.empty()) {
192 size_t idx = toParse.find(
'{');
193 if (idx == StringRef::npos) {
195 items.push_back(toParse);
200 items.push_back(toParse.take_front(idx));
201 toParse = toParse.drop_front(idx);
204 if (toParse.size() < 2) {
205 return emitError() <<
"expected '}' after unescaped '{' at end of string";
208 char nextChar = toParse[1];
209 if (nextChar ==
'{') {
211 items.push_back(toParse.take_front(1));
212 toParse = toParse.drop_front(2);
215 if (nextChar ==
'}') {
217 toParse = toParse.drop_front(2);
222 return emitError() <<
"expected '}' after unescaped '{'";
233LogicalResult AddressOfOp::verify() {
234 emitc::LValueType referenceType = getReference().getType();
235 emitc::PointerType resultType = getResult().getType();
237 if (referenceType.getValueType() != resultType.getPointee())
238 return emitOpError(
"requires result to be a pointer to the type "
239 "referenced by operand");
248LogicalResult AddOp::verify() {
249 Type lhsType = getLhs().getType();
250 Type rhsType = getRhs().getType();
252 if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType))
253 return emitOpError(
"requires that at most one operand is a pointer");
255 if ((isa<emitc::PointerType>(lhsType) &&
256 !isa<IntegerType, emitc::OpaqueType>(rhsType)) ||
257 (isa<emitc::PointerType>(rhsType) &&
258 !isa<IntegerType, emitc::OpaqueType>(lhsType)))
259 return emitOpError(
"requires that one operand is an integer or of opaque "
260 "type if the other is a pointer");
269LogicalResult ApplyOp::verify() {
270 StringRef applicableOperatorStr = getApplicableOperator();
273 if (applicableOperatorStr.empty())
274 return emitOpError(
"applicable operator must not be empty");
277 if (applicableOperatorStr !=
"&" && applicableOperatorStr !=
"*")
278 return emitOpError(
"applicable operator is illegal");
280 Type operandType = getOperand().getType();
281 Type resultType = getResult().getType();
282 if (applicableOperatorStr ==
"&") {
283 if (!llvm::isa<emitc::LValueType>(operandType))
284 return emitOpError(
"operand type must be an lvalue when applying `&`");
285 if (!llvm::isa<emitc::PointerType>(resultType))
286 return emitOpError(
"result type must be a pointer when applying `&`");
288 if (!llvm::isa<emitc::PointerType>(operandType))
289 return emitOpError(
"operand type must be a pointer when applying `*`");
301LogicalResult emitc::AssignOp::verify() {
304 if (!variable.getDefiningOp())
305 return emitOpError() <<
"cannot assign to block argument";
307 Type valueType = getValue().getType();
308 Type variableType = variable.getType().getValueType();
309 if (variableType != valueType)
310 return emitOpError() <<
"requires value's type (" << valueType
311 <<
") to match variable's type (" << variableType
312 <<
")\n variable: " << variable
313 <<
"\n value: " << getValue() <<
"\n";
322 Type input = inputs.front(), output = outputs.front();
324 if (
auto arrayType = dyn_cast<emitc::ArrayType>(input)) {
325 if (
auto pointerType = dyn_cast<emitc::PointerType>(output)) {
326 return (arrayType.getElementType() == pointerType.getPointee()) &&
327 arrayType.getShape().size() == 1 && arrayType.getShape()[0] >= 1;
343LogicalResult emitc::CallOpaqueOp::verify() {
345 if (getCallee().empty())
348 if (std::optional<ArrayAttr> argsAttr = getArgs()) {
350 auto intAttr = llvm::dyn_cast<IntegerAttr>(arg);
351 if (intAttr && llvm::isa<IndexType>(intAttr.getType())) {
356 return emitOpError(
"index argument is out of range");
359 }
else if (llvm::isa<ArrayAttr>(
367 if (std::optional<ArrayAttr> templateArgsAttr = getTemplateArgs()) {
368 for (
Attribute tArg : *templateArgsAttr) {
369 if (!llvm::isa<TypeAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(tArg))
370 return emitOpError(
"template argument has invalid type");
374 if (llvm::any_of(getResultTypes(), llvm::IsaPred<ArrayType>)) {
375 return emitOpError() <<
"cannot return array type";
385LogicalResult emitc::ConstantOp::verify() {
389 if (
auto opaqueValue = llvm::dyn_cast<emitc::OpaqueAttr>(value)) {
390 if (opaqueValue.getValue().empty())
/// Folding hook for emitc.constant: the op always folds to its own value
/// attribute (whatever getValue() returns); the adaptor is not consulted.
396OpFoldResult emitc::ConstantOp::fold(FoldAdaptor adaptor) {
return getValue(); }
402LogicalResult DereferenceOp::verify() {
403 emitc::PointerType pointerType = getPointer().getType();
406 return emitOpError(
"requires result to be an lvalue of the type "
407 "pointed to by operand");
418struct RemoveRecurringExpressionOperands
420 using OpRewritePattern<ExpressionOp>::OpRewritePattern;
421 LogicalResult matchAndRewrite(ExpressionOp expressionOp,
422 PatternRewriter &rewriter)
const override {
427 for (
auto [i, operand] : llvm::enumerate(expressionOp.getDefs())) {
428 if (uniqueOperands.contains(operand))
430 uniqueOperands.insert(operand);
431 firstIndexOf[operand] = i;
435 if (uniqueOperands.size() == expressionOp.getDefs().size())
440 auto uniqueExpression = emitc::ExpressionOp::create(
441 rewriter, expressionOp.getLoc(), expressionOp.getResult().getType(),
442 uniqueOperands.getArrayRef(), expressionOp.getDoNotInline());
443 Block &uniqueExpressionBody = uniqueExpression.createBody();
448 Block *expressionBody = expressionOp.getBody();
449 for (
auto [operand, arg] :
450 llvm::zip(expressionOp.getOperands(), expressionBody->
getArguments()))
451 mapper.
map(arg, uniqueExpressionBody.
getArgument(firstIndexOf[operand]));
454 for (Operation &opToClone : *expressionOp.getBody())
455 rewriter.
clone(opToClone, mapper);
458 rewriter.
replaceOp(expressionOp, uniqueExpression);
469 using OpRewritePattern<ExpressionOp>::OpRewritePattern;
470 LogicalResult matchAndRewrite(ExpressionOp expressionOp,
471 PatternRewriter &rewriter)
const override {
472 auto yieldOp = cast<YieldOp>(expressionOp.getBody()->getTerminator());
473 Value yieldedValue = yieldOp.getResult();
474 auto blockArg = dyn_cast_if_present<BlockArgument>(yieldedValue);
478 expressionOp.getOperand(blockArg.getArgNumber()));
487 results.
add<RemoveRecurringExpressionOperands, FoldTrivialExpressionOp>(
496 result.addAttribute(ExpressionOp::getDoNotInlineAttrName(
result.name),
501 "expected function type");
502 auto fnType = llvm::dyn_cast<FunctionType>(type);
505 "expected function type");
509 if (fnType.getNumResults() != 1)
511 "expected single return type");
512 result.addTypes(fnType.getResults());
516 bool enableNameShadowing = uniqueOperands.size() ==
result.operands.size();
518 if (enableNameShadowing) {
519 for (
auto [unresolvedOperand, operandType] :
520 llvm::zip(operands, fnType.getInputs())) {
522 argInfo.
ssaName = unresolvedOperand;
523 argInfo.
type = operandType;
524 argsInfo.push_back(argInfo);
528 if (parser.
parseRegion(*body, argsInfo, enableNameShadowing))
530 if (!enableNameShadowing) {
533 beforeRegionLoc,
"with recurring operands expected block arguments");
541 auto operands = getDefs();
546 bool printEntryBlockArgs =
true;
547 if (uniqueOperands.size() == operands.size()) {
549 printEntryBlockArgs =
false;
556 auto yieldOp = cast<YieldOp>(getBody()->getTerminator());
557 Value yieldedValue = yieldOp.getResult();
561LogicalResult ExpressionOp::verify() {
562 Type resultType = getResult().getType();
563 Region ®ion = getRegion();
568 return emitOpError(
"must yield a value at termination");
571 Value yieldResult = yield.getResult();
574 return emitOpError(
"must yield a value at termination");
579 return emitOpError(
"yielded value has no defining op");
582 return emitOpError(
"yielded value not defined within expression");
586 if (resultType != yieldType)
587 return emitOpError(
"requires yielded type to match return type");
590 auto expressionInterface = dyn_cast<emitc::CExpressionInterface>(op);
591 if (!expressionInterface)
592 return emitOpError(
"contains an unsupported operation");
593 if (op.getNumResults() != 1)
594 return emitOpError(
"requires exactly one result for each operation");
597 return emitOpError(
"contains an unused operation");
604 worklist.push_back(rootOp);
605 while (!worklist.empty()) {
608 if (visited.contains(op)) {
609 auto cExpr = cast<CExpressionInterface>(op);
610 if (!cExpr.alwaysInline() && cExpr.hasSideEffects())
612 "requires exactly one use for operations with side effects");
616 if (
Operation *def = operand.getDefiningOp()) {
617 worklist.push_back(def);
623 if (getDoNotInline() &&
624 cast<emitc::CExpressionInterface>(rootOp).alwaysInline()) {
625 return emitOpError(
"root operation must be inlined but expression is marked"
647 ForOp::ensureTerminator(*bodyRegion, builder,
result.location);
674 regionArgs.push_back(inductionVariable);
683 regionArgs.front().type = type;
694 ForOp::ensureTerminator(*body, builder,
result.location);
704 p <<
" " << getInductionVar() <<
" = " <<
getLowerBound() <<
" to "
709 p <<
" : " << t <<
' ';
716LogicalResult ForOp::verifyRegions() {
719 if (getBody()->getNumArguments() != 1)
720 return emitOpError(
"expected body to have a single block argument for the "
721 "induction variable");
725 "expected induction variable to be same type as bounds and step");
738 return emitOpError(
"requires a 'callee' symbol reference attribute");
742 <<
"' does not reference a valid function";
745 auto fnType = fn.getFunctionType();
746 if (fnType.getNumInputs() != getNumOperands())
747 return emitOpError(
"incorrect number of operands for callee");
749 for (
unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
750 if (getOperand(i).
getType() != fnType.getInput(i))
751 return emitOpError(
"operand type mismatch: expected operand type ")
752 << fnType.getInput(i) <<
", but provided "
753 << getOperand(i).getType() <<
" for operand number " << i;
755 if (fnType.getNumResults() != getNumResults())
756 return emitOpError(
"incorrect number of results for callee");
758 for (
unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
759 if (getResult(i).
getType() != fnType.getResult(i)) {
761 diag.attachNote() <<
" op result types: " << getResultTypes();
762 diag.attachNote() <<
"function result types: " << fnType.getResults();
769FunctionType CallOp::getCalleeType() {
770 return FunctionType::get(
getContext(), getOperandTypes(), getResultTypes());
780 auto fnAttr = getSymNameAttr();
782 return emitOpError(
"requires a 'sym_name' symbol reference attribute");
786 <<
"' does not reference a valid function";
800 state.
addAttribute(getFunctionTypeAttrName(state.
name), TypeAttr::get(type));
804 if (argAttrs.empty())
806 assert(type.getNumInputs() == argAttrs.size());
808 builder, state, argAttrs, {},
809 getArgAttrsAttrName(state.
name), getResAttrsAttrName(state.
name));
820 getFunctionTypeAttrName(
result.name), buildFuncType,
821 getArgAttrsAttrName(
result.name), getResAttrsAttrName(
result.name));
826 p, *
this,
false, getFunctionTypeAttrName(),
827 getArgAttrsAttrName(), getResAttrsAttrName());
830LogicalResult FuncOp::verify() {
831 if (llvm::any_of(getArgumentTypes(), llvm::IsaPred<LValueType>)) {
832 return emitOpError(
"cannot have lvalue type as argument");
835 if (getNumResults() > 1)
836 return emitOpError(
"requires zero or exactly one result, but has ")
839 if (getNumResults() == 1 && isa<ArrayType>(getResultTypes()[0]))
849LogicalResult ReturnOp::verify() {
850 auto function = cast<FuncOp>((*this)->getParentOp());
853 if (getNumOperands() != function.getNumResults())
855 << getNumOperands() <<
" operands, but enclosing function (@"
856 << function.getName() <<
") returns " << function.getNumResults();
858 if (function.getNumResults() == 1)
859 if (getOperand().
getType() != function.getResultTypes()[0])
860 return emitError() <<
"type of the return operand ("
861 << getOperand().getType()
862 <<
") doesn't match function result type ("
863 << function.getResultTypes()[0] <<
")"
864 <<
" in function @" << function.getName();
873 bool addThenBlock,
bool addElseBlock) {
874 assert((!addElseBlock || addThenBlock) &&
875 "must not create else block w/o then block");
889 bool withElseRegion) {
899 if (withElseRegion) {
907 assert(thenBuilder &&
"the builder callback for 'then' must be present");
914 thenBuilder(builder,
result.location);
920 elseBuilder(builder,
result.location);
926 result.regions.reserve(2);
955 bool printBlockTerminators =
false;
957 p <<
" " << getCondition();
961 printBlockTerminators);
964 Region &elseRegion = getElseRegion();
965 if (!elseRegion.
empty()) {
969 printBlockTerminators);
991 Region *elseRegion = &this->getElseRegion();
992 if (elseRegion->
empty())
1005 FoldAdaptor adaptor(operands, *
this);
1006 auto boolAttr = dyn_cast_or_null<BoolAttr>(adaptor.getCondition());
1007 if (!boolAttr || boolAttr.getValue())
1008 regions.emplace_back(&getThenRegion());
1011 if (!boolAttr || !boolAttr.getValue()) {
1012 if (!getElseRegion().empty())
1013 regions.emplace_back(&getElseRegion());
1019void IfOp::getRegionInvocationBounds(
1022 if (
auto cond = llvm::dyn_cast_or_null<BoolAttr>(operands[0])) {
1025 invocationBounds.emplace_back(0, cond.getValue() ? 1 : 0);
1026 invocationBounds.emplace_back(0, cond.getValue() ? 0 : 1);
1029 invocationBounds.assign(2, {0, 1});
1038 bool standardInclude = getIsStandardInclude();
1041 if (standardInclude)
1043 p <<
"\"" << getInclude() <<
"\"";
1044 if (standardInclude)
1059 <<
"expected trailing '>' for standard include";
1061 if (standardInclude)
1062 result.addAttribute(
"is_standard_include",
1073LogicalResult emitc::LiteralOp::verify() {
1074 if (getValue().empty())
1075 return emitOpError() <<
"value must not be empty";
1082LogicalResult SubOp::verify() {
1083 Type lhsType = getLhs().getType();
1084 Type rhsType = getRhs().getType();
1085 Type resultType = getResult().getType();
1087 if (isa<emitc::PointerType>(rhsType) && !isa<emitc::PointerType>(lhsType))
1088 return emitOpError(
"rhs can only be a pointer if lhs is a pointer");
1090 if (isa<emitc::PointerType>(lhsType) &&
1091 !isa<IntegerType, emitc::OpaqueType, emitc::PointerType>(rhsType))
1092 return emitOpError(
"requires that rhs is an integer, pointer or of opaque "
1093 "type if lhs is a pointer");
1095 if (isa<emitc::PointerType>(lhsType) && isa<emitc::PointerType>(rhsType) &&
1096 !isa<IntegerType, emitc::PtrDiffTType, emitc::OpaqueType>(resultType))
1097 return emitOpError(
"requires that the result is an integer, ptrdiff_t or "
1098 "of opaque type if lhs and rhs are pointers");
1106LogicalResult emitc::VariableOp::verify() {
1114LogicalResult emitc::YieldOp::verify() {
1119 return emitOpError() <<
"yields a value not returned by parent";
1122 return emitOpError() <<
"does not yield a value to be returned by parent";
1124 if (
result && isa<emitc::LValueType>(
result.getType()) &&
1125 !isa<ExpressionOp>(containingOp))
1126 return emitOpError() <<
"yielding lvalues is not supported for this op";
1135LogicalResult emitc::SubscriptOp::verify() {
1137 if (
auto arrayType = llvm::dyn_cast<emitc::ArrayType>(getValue().
getType())) {
1139 if (
getIndices().size() != (
size_t)arrayType.getRank()) {
1140 return emitOpError() <<
"on array operand requires number of indices ("
1142 <<
") to match the rank of the array type ("
1143 << arrayType.getRank() <<
")";
1146 for (
unsigned i = 0, e =
getIndices().size(); i != e; ++i) {
1149 return emitOpError() <<
"on array operand requires index operand " << i
1150 <<
" to be integer-like, but got " << type;
1154 Type elementType = arrayType.getElementType();
1156 if (elementType != resultType) {
1157 return emitOpError() <<
"on array operand requires element type ("
1158 << elementType <<
") and result type (" << resultType
1165 if (
auto pointerType =
1166 llvm::dyn_cast<emitc::PointerType>(getValue().
getType())) {
1170 <<
"on pointer operand requires one index operand, but got "
1176 return emitOpError() <<
"on pointer operand requires index operand to be "
1177 "integer-like, but got "
1181 Type pointeeType = pointerType.getPointee();
1183 if (pointeeType != resultType) {
1184 return emitOpError() <<
"on pointer operand requires pointee type ("
1185 << pointeeType <<
") and result type (" << resultType
1200LogicalResult emitc::VerbatimOp::verify() {
1204 FailureOr<SmallVector<ReplacementItem>> fmt =
1209 size_t numPlaceholders = llvm::count_if(*fmt, [](
ReplacementItem &item) {
1210 return std::holds_alternative<Placeholder>(item);
1213 if (numPlaceholders != getFmtArgs().size()) {
1215 <<
"requires operands for each placeholder in the format string";
1220FailureOr<SmallVector<ReplacementItem>> emitc::VerbatimOp::parseFormatString() {
1222 return ::parseFormatString(getValue(), getFmtArgs());
1229#include "mlir/Dialect/EmitC/IR/EmitCEnums.cpp.inc"
1235#define GET_ATTRDEF_CLASSES
1236#include "mlir/Dialect/EmitC/IR/EmitCAttributes.cpp.inc"
1242#define GET_TYPEDEF_CLASSES
1243#include "mlir/Dialect/EmitC/IR/EmitCTypes.cpp.inc"
1264 if (!isValidElementType(elementType))
1265 return parser.
emitError(typeLoc,
"invalid array element type '")
1266 << elementType <<
"'",
1270 return parser.
getChecked<ArrayType>(dimensions, elementType);
1273void emitc::ArrayType::print(
AsmPrinter &printer)
const {
1276 printer << dim <<
'x';
1282LogicalResult emitc::ArrayType::verify(
1286 return emitError() <<
"shape must not be empty";
1290 return emitError() <<
"dimensions must have non-negative size";
1294 return emitError() <<
"element type must not be none";
1296 if (!isValidElementType(elementType))
1297 return emitError() <<
"invalid array element type";
1304 Type elementType)
const {
1306 return emitc::ArrayType::get(
getShape(), elementType);
1307 return emitc::ArrayType::get(*
shape, elementType);
1314LogicalResult mlir::emitc::LValueType::verify(
1321 <<
"!emitc.lvalue must wrap supported emitc type, but got " << value;
1323 if (llvm::isa<emitc::ArrayType>(value))
1324 return emitError() <<
"!emitc.lvalue cannot wrap !emitc.array type";
1333LogicalResult mlir::emitc::OpaqueType::verify(
1335 llvm::StringRef value) {
1336 if (value.empty()) {
1337 return emitError() <<
"expected non empty string in !emitc.opaque type";
1339 if (value.back() ==
'*') {
1340 return emitError() <<
"pointer not allowed as outer type with "
1341 "!emitc.opaque, use !emitc.ptr instead";
1350LogicalResult mlir::emitc::PointerType::verify(
1352 if (llvm::isa<emitc::LValueType>(value))
1353 return emitError() <<
"pointers to lvalues are not allowed";
1372 if (
auto array = llvm::dyn_cast<ArrayType>(type))
1373 return RankedTensorType::get(array.getShape(), array.getElementType());
1384 typeAttr = TypeAttr::get(type);
1392 if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1395 <<
"initial value should be a integer, float, elements or opaque "
1400LogicalResult GlobalOp::verify() {
1404 if (getInitialValue().has_value()) {
1405 Attribute initValue = getInitialValue().value();
1408 if (
auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1409 auto arrayType = llvm::dyn_cast<ArrayType>(
getType());
1413 Type initType = elementsAttr.getType();
1415 if (initType != tensorType) {
1416 return emitOpError(
"initial value expected to be of type ")
1417 <<
getType() <<
", but was of type " << initType;
1419 }
else if (
auto intAttr = dyn_cast<IntegerAttr>(initValue)) {
1420 if (intAttr.getType() !=
getType()) {
1421 return emitOpError(
"initial value expected to be of type ")
1422 <<
getType() <<
", but was of type " << intAttr.getType();
1424 }
else if (
auto floatAttr = dyn_cast<FloatAttr>(initValue)) {
1425 if (floatAttr.getType() !=
getType()) {
1426 return emitOpError(
"initial value expected to be of type ")
1427 <<
getType() <<
", but was of type " << floatAttr.getType();
1429 }
else if (!isa<emitc::OpaqueAttr>(initValue)) {
1430 return emitOpError(
"initial value should be a integer, float, elements "
1431 "or opaque attribute, but got ")
1435 if (getStaticSpecifier() && getExternSpecifier()) {
1436 return emitOpError(
"cannot have both static and extern specifiers");
1452 << getName() <<
"' does not reference a valid emitc.global";
1454 Type resultType = getResult().getType();
1455 Type globalType = global.getType();
1458 if (llvm::isa<ArrayType>(globalType)) {
1459 if (globalType != resultType)
1460 return emitOpError(
"on array type expects result type ")
1461 << resultType <<
" to match type " << globalType
1462 <<
" of the global @" << getName();
1467 auto lvalueType = dyn_cast<LValueType>(resultType);
1469 return emitOpError(
"on non-array type expects result type to be an "
1470 "lvalue type for the global @")
1472 if (lvalueType.getValueType() != globalType)
1473 return emitOpError(
"on non-array type expects result inner type ")
1474 << lvalueType.getValueType() <<
" to match type " << globalType
1475 <<
" of the global @" << getName();
1490 Region ®ion = *caseRegions.emplace_back(std::make_unique<Region>());
1494 caseValues.push_back(value);
1503 for (
auto [value, region] : llvm::zip(cases.
asArrayRef(), caseRegions)) {
1505 p <<
"case " << value <<
' ';
1511 const Twine &name) {
1512 auto yield = dyn_cast<emitc::YieldOp>(region.
front().
back());
1514 return op.emitOpError(
"expected region to end with emitc.yield, but got ")
1517 if (yield.getNumOperands() != 0) {
1518 return (op.emitOpError(
"expected each region to return ")
1519 <<
"0 values, but " << name <<
" returns "
1520 << yield.getNumOperands())
1521 .attachNote(yield.getLoc())
1522 <<
"see yield operation here";
1528LogicalResult emitc::SwitchOp::verify() {
1530 return emitOpError(
"unsupported type ") << getArg().getType();
1532 if (getCases().size() != getCaseRegions().size()) {
1534 << getCaseRegions().size() <<
" case regions but "
1535 << getCases().size() <<
" case values";
1539 for (
int64_t value : getCases())
1540 if (!valueSet.insert(value).second)
1541 return emitOpError(
"has duplicate case value: ") << value;
1546 for (
auto [idx, caseRegion] : llvm::enumerate(getCaseRegions()))
/// Returns the number of case values of the switch, i.e. the size of the
/// `cases` array. The default region is not counted.
1553unsigned emitc::SwitchOp::getNumCases() {
return getCases().size(); }
/// Returns the first (entry) block of the switch's default region.
/// NOTE(review): no emptiness check is performed here; front() assumes the
/// default region holds at least one block. Confirm that the verifier guarantees this.
1555Block &emitc::SwitchOp::getDefaultBlock() {
return getDefaultRegion().
front(); }
1557Block &emitc::SwitchOp::getCaseBlock(
unsigned idx) {
1558 assert(idx < getNumCases() &&
"case index out-of-bounds");
1559 return getCaseRegions()[idx].front();
1562void SwitchOp::getSuccessorRegions(
1564 llvm::append_range(successors, getRegions());
1570 Type type = attr.getType();
1572 return attr.getInt();
1574 return attr.getSInt();
1576 return static_cast<int64_t>(attr.getUInt());
1577 return std::nullopt;
1580void SwitchOp::getEntrySuccessorRegions(
1583 FoldAdaptor adaptor(operands, *
this);
1586 auto arg = dyn_cast_or_null<IntegerAttr>(adaptor.getArg());
1588 llvm::append_range(successors, getRegions());
1595 llvm::append_range(successors, getRegions());
1601 for (
auto [caseValue, caseRegion] : llvm::zip(getCases(), getCaseRegions())) {
1602 if (caseValue == *argValue) {
1603 successors.emplace_back(&caseRegion);
1607 successors.emplace_back(&getDefaultRegion());
1610void SwitchOp::getRegionInvocationBounds(
1612 auto operandValue = llvm::dyn_cast_or_null<IntegerAttr>(operands.front());
1613 if (!operandValue) {
1620 if (!maybeIntValue) {
1626 unsigned liveIndex = getNumRegions() - 1;
1627 const auto *iteratorToInt = llvm::find(getCases(), *maybeIntValue);
1629 liveIndex = iteratorToInt != getCases().end()
1630 ? std::distance(getCases().begin(), iteratorToInt)
1633 for (
unsigned regIndex = 0, regNum = getNumRegions(); regIndex < regNum;
1635 bounds.emplace_back(0, regIndex == liveIndex);
1662 if (
auto array = llvm::dyn_cast<ArrayType>(type))
1663 return RankedTensorType::get(array.getShape(), array.getElementType());
1674 typeAttr = TypeAttr::get(type);
1682 if (!llvm::isa<ElementsAttr, IntegerAttr, FloatAttr, emitc::OpaqueAttr>(
1685 <<
"initial value should be a integer, float, elements or opaque "
1690LogicalResult FieldOp::verify() {
1695 if (!parentOp || !isa<emitc::ClassOp>(parentOp))
1696 return emitOpError(
"field must be nested within an emitc.class operation");
1698 StringAttr symName = getSymNameAttr();
1699 if (!symName || symName.getValue().empty())
1700 return emitOpError(
"field must have a non-empty symbol name");
1709LogicalResult GetFieldOp::verify() {
1710 auto parentClassOp = getOperation()->getParentOfType<emitc::ClassOp>();
1711 if (!parentClassOp.getOperation())
1712 return emitOpError(
" must be nested within an emitc.class operation");
1723 << fieldNameAttr <<
"' not found in the class";
1725 Type getFieldResultType = getResult().getType();
1726 Type fieldType = fieldOp.getType();
1728 if (fieldType != getFieldResultType)
1730 << getFieldResultType <<
" does not match field '" << fieldNameAttr
1731 <<
"' type " << fieldType;
1748LogicalResult emitc::DoOp::verify() {
1749 Block &condBlock = getConditionRegion().
front();
1753 "condition region must contain exactly two operations: "
1754 "'emitc.expression' followed by 'emitc.yield', but found ")
1758 auto exprOp = dyn_cast<emitc::ExpressionOp>(first);
1760 return emitOpError(
"expected first op in condition region to be "
1761 "'emitc.expression', but got ")
1764 if (!exprOp.getResult().getType().isInteger(1))
1765 return emitOpError(
"emitc.expression in condition region must return "
1766 "'i1', but returns ")
1767 << exprOp.getResult().getType();
1770 auto condYield = dyn_cast<emitc::YieldOp>(last);
1772 return emitOpError(
"expected last op in condition region to be "
1773 "'emitc.yield', but got ")
1776 if (condYield.getNumOperands() != 1)
1777 return emitOpError(
"expected condition region to return 1 value, but "
1779 << condYield.getNumOperands() <<
" values";
1781 if (condYield.getOperand(0) != exprOp.getResult())
1782 return emitError(
"'emitc.yield' must return result of "
1783 "'emitc.expression' from this condition region");
1787 return emitOpError(
"body region must not contain terminator");
1800 if (bodyRegion->
empty())
1810#include "mlir/Dialect/EmitC/IR/EmitCInterfaces.cpp.inc"
1812#define GET_OP_CLASSES
1813#include "mlir/Dialect/EmitC/IR/EmitC.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static std::optional< int64_t > getLowerBound(Value iv)
Gets the constant lower bound on an iv.
static std::optional< int64_t > getIntAttrValue(IntegerAttr attr)
Returns the int64_t value of an IntegerAttr regardless of whether its type is signless,...
static LogicalResult verifyInitializationAttribute(Operation *op, Attribute value)
Check that the type of the initial value is compatible with the operations result type.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static ParseResult parseEmitCGlobalOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static Type getInitializerTypeForField(Type type)
static ParseResult parseEmitCFieldOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
FailureOr< SmallVector< ReplacementItem > > parseFormatString(StringRef toParse, ArgType fmtArgs, llvm::function_ref< mlir::InFlightDiagnostic()> emitError={})
Parse a format string and return a list of its parts.
static void printEmitCGlobalOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static ParseResult parseSwitchCases(OpAsmParser &parser, DenseI64ArrayAttr &cases, SmallVectorImpl< std::unique_ptr< Region > > &caseRegions)
Parse the case regions and values.
static void printEmitCFieldOpTypeAndInitialValue(OpAsmPrinter &p, FieldOp op, TypeAttr type, Attribute initialValue)
static void printSwitchCases(OpAsmPrinter &p, Operation *op, DenseI64ArrayAttr cases, RegionRange caseRegions)
Print the case regions and values.
static Type getInitializerTypeForGlobal(Type type)
static Type getElementType(Type type)
Determine the element type of type.
static std::string diag(const llvm::Value &value)
static Type getValueType(Attribute attr)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseDimensionList(SmallVectorImpl< int64_t > &dimensions, bool allowDynamic=true, bool withTrailingX=true)=0
Parse a dimension list of a tensor or memref type.
virtual ParseResult parseOptionalGreater()=0
Parse a '>' token if present.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual OptionalParseResult parseOptionalAttribute(Attribute &result, Type type={})=0
Parse an arbitrary optional attribute of a given type and return it in result.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
auto getChecked(SMLoc loc, ParamsT &&...params)
Invoke the getChecked method of the given Attribute or Type class, using the provided location to emi...
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
This base class exposes generic asm printer hooks, usable across the various derived printers.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printNewline()
Print a newline and indent the printer to the start of the current operation/attribute/type.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
OpListType & getOperations()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
BlockArgListType getArguments()
iterator_range< iterator > without_terminator()
Return an iterator range over the operation within this block excluding the terminator operation at t...
This class is a general helper class for creating context-global objects like types,...
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
IntegerType getIntegerType(unsigned width)
StringAttr getStringAttr(const Twine &bytes)
NamedAttribute getNamedAttr(StringRef name, Attribute val)
A symbol reference with a reference path containing a single element.
void map(Value from, Value to)
Inserts a new mapping for 'from' to 'to'.
This class represents a diagnostic that is inflight and set to be reported.
This class represents upper and lower bounds on the number of times a region of a RegionBranchOpInter...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse)=0
Renumber the arguments for the specified region to the same names as the SSA values in namesToUse.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
void printFunctionalType(Operation *op)
Print the complete type of an operation in functional form.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
This class represents a single result from folding an operation.
type_range getType() const
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperationName getName()
The name of an operation is the key identifier for it.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class implements Optional functionality for ParseResult.
bool has_value() const
Returns true if we contain a valid ParseResult value.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isSignedInteger() const
Return true if this is a signed integer type (with the specified width).
bool isSignlessInteger() const
Return true if this is a signless integer type (with the specified width).
bool isUnsignedInteger() const
Return true if this is an unsigned integer type (with the specified width).
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
ArrayRef< T > asArrayRef() const
A named class for passing around the variadic flag.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void buildTerminatedBody(OpBuilder &builder, Location loc)
Default callback for builders of ops carrying a region.
std::variant< StringRef, Placeholder > ReplacementItem
bool isFundamentalType(mlir::Type type)
Determines whether type is a valid fundamental C++ type in EmitC.
bool isSupportedFloatType(mlir::Type type)
Determines whether type is a valid floating-point type in EmitC.
bool isSupportedEmitCType(mlir::Type type)
Determines whether type is valid in EmitC.
bool isPointerWideType(mlir::Type type)
Determines whether type is a emitc.size_t/ssize_t type.
bool isIntegerIndexOrOpaqueType(Type type)
Determines whether type is integer like, i.e.
bool isSupportedIntegerType(mlir::Type type)
Determines whether type is a valid integer type in EmitC.
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
llvm::function_ref< Fn > function_ref
UnresolvedOperand ssaName
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.