#include "llvm/ADT/SetOperations.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/raw_ostream.h"

#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"

#include "ShapeCanonicalization.inc"

  return RankedTensorType::get({rank}, IndexType::get(ctx));

  auto ranked = type.dyn_cast<RankedTensorType>();
  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();

  auto type = inputOp.getArg().getType().cast<ShapedType>();
  llvm::append_range(shapeValues, type.getShape());

  llvm::append_range(shapeValues, attr.getValues<int64_t>());
  return llvm::any_of(operandTypes, [](Type ty) {
    return ty.isa<SizeType, ShapeType, ValueShapeType>();
  });
    if (!resultTy.isa<SizeType>())
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `size` to propagate them";

    if (!resultTy.isa<ShapeType>())
      return op->emitOpError()
             << "if at least one of the operands can hold error values then "
                "the result must be of type `shape` to propagate them";
template <typename... Ty>
static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
  return typeRange.size() == 1 && typeRange.front().isa<Ty...>();
}

template <typename... Ty, typename... ranges>
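// The body of this variadic overload is elided in this excerpt; it presumably
// applies the same single-type check to every range, along the lines of
// `eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...)`.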
void ShapeDialect::initialize() {

#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"

  addInterfaces<ShapeInlinerInterface>();

  allowUnknownOperations();

    return builder.create<ConstShapeOp>(loc, type,

  if (type.isa<SizeType>())
    return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
  if (type.isa<WitnessType>())

  if (arith::ConstantOp::isBuildableWith(value, type))
    return builder.create<arith::ConstantOp>(loc, type, value);
  if (attribute.getName() == "shape.lib") {

      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)

                 << symbolRef << " required to be shape function library";

      for (auto it : arr) {
        if (!it.isa<SymbolRefAttr>())
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(

                 << it << " does not refer to FunctionLibraryOp";
        for (auto mapping : shapeFnLib.getMapping()) {
          if (!key.insert(mapping.getName()).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.getName() << "`";

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  return operands.back();

  bool yieldsResults = !getResults().empty();

  p << " " << getWitness();

    p << " -> (" << getResultTypes() << ")";

    auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
    if (!witness || !witness.getPassingAttr())

    AssumingOp::inlineRegionIntoParent(op, rewriter);
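// This is the core of the AssumingWithTrue canonicalization registered below:
// once the witness is a constant passing shape.const_witness, the assuming
// region can simply be inlined into the parent block.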
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {

    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    Value opResult, yieldOperand;
    for (auto it : llvm::zip(op.getResults(), yieldOp.getOperands())) {
      std::tie(opResult, yieldOperand) = it;
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);

    if (newYieldOperands.size() == yieldOp->getNumOperands())

    auto newOp = rewriter.create<AssumingOp>(
        op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);

        replacementValues.push_back(*src++);

    rewriter.replaceOp(op, replacementValues);

  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);

void AssumingOp::getSuccessorRegions(
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,

  auto *assumingBlock = op.getBody();

  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  auto &yieldOp = assumingBlock->back();

  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);

void AssumingOp::build(

  for (Value v : yieldValues)
    assumingTypes.push_back(v.getType());
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});

  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);

  return constFoldBinaryOp<IntegerAttr>(
      operands, [](APInt a, const APInt &b) { return std::move(a) + b; });
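// Illustrative IR (not from this file): with constant operands such as
//   %c2 = shape.const_size 2
//   %c3 = shape.const_size 3
//   %r  = shape.add %c2, %c3 : !shape.size, !shape.size -> !shape.size
// the fold above combines the two IntegerAttrs into a constant 5.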
    for (Value operand : op.getInputs()) {
      if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
        operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
      else
        operands.push_back(operand);

    if (operands.size() == op.getNumOperands())
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {

    for (Value operand : op.getInputs()) {

      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();

      operands.insert(broadcastable);

    if (operands.size() <= 1)

    for (auto cstr : operands) {

      shapes.emplace_back(cstr, std::move(shapesSet));

    llvm::sort(shapes, [](auto a, auto b) {
      return a.first.getNumOperands() > b.first.getNumOperands();

    for (unsigned i = 0; i < shapes.size(); ++i) {
      auto isSubset = [&](auto pair) {
        return llvm::set_is_subset(pair.second, shapes[i].second);

      auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
      for (auto *it0 = it; it0 < shapes.end(); ++it0)
        markedForErase.push_back(it0->first);
      shapes.erase(it, shapes.end());

    if (markedForErase.empty())

    for (auto &shape : shapes)
      uniqueConstraints.push_back(shape.first.getResult());

    for (auto &op : markedForErase)
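// Summary of the pattern above: collect the cstr_broadcastable operands of an
// assuming_all, sort them by operand count, and drop every constraint whose
// shape-operand set is a subset of a larger constraint's set, since the larger
// constraint already implies it; the subsumed ops become candidates for
// erasure.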
struct AssumingAllToCstrEqCanonicalization

    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();

      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)

      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());

template <typename OpTy>

    if (unique.size() < op.getNumOperands()) {

          unique.takeVector(), op->getAttrs());
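// RemoveDuplicateOperandsPattern (intent): gather the operands into a uniquing
// container and, if deduplication removed anything, rebuild the op with only
// the unique operands while keeping its attributes.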
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);

  for (int idx = operands.size() - 1; idx >= 0; idx--) {

      getOperation()->eraseOperand(idx);

  if (getNumOperands() == 0)
    return emitOpError("no operands specified");

  if (getShapes().size() == 1) {

    if (getShapes().front().getType() != getType())

    return getShapes().front();

  if (getShapes().size() > 2)

  if (!operands[0] || !operands[1])

  auto lhsShape = llvm::to_vector<6>(

  auto rhsShape = llvm::to_vector<6>(
template <typename OpTy>

    auto isPotentiallyNonEmptyShape = [](Value shape) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {
        if (extentTensorTy.getDimSize(0) == 0)

      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        if (constShape.getShape().empty())

    auto newOperands = llvm::to_vector<8>(
        llvm::make_filter_range(op->getOperands(), isPotentiallyNonEmptyShape));

    if (newOperands.size() < op.getNumOperands()) {
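// RemoveEmptyShapeOperandsPattern: operands that are statically known to be
// empty shapes (zero-length extent tensors or empty constant shapes) cannot
// influence a broadcast, so they are filtered out and the op is rebuilt from
// the remaining operands.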
struct BroadcastForwardSingleOperandPattern

    if (op.getNumOperands() != 1)

    Value replacement = op.getShapes().front();

    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (op.getType().isa<ShapeType>()) {
        replacement = rewriter.create<FromExtentTensorOp>(loc, replacement);
      } else {
        assert(!op.getType().isa<ShapeType>() &&
               "expect extent tensor cast");
        replacement =
            rewriter.create<tensor::CastOp>(loc, op.getType(), replacement);
struct BroadcastFoldConstantOperandsPattern

    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {

                llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
                newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;

      newShapeOperands.push_back(shape);

    if (op.getNumOperands() - newShapeOperands.size() < 2)

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},

    newShapeOperands.push_back(rewriter.create<ConstShapeOp>(
        op.getLoc(), foldedConstantOperandsTy,
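// Here constant shape operands are folded eagerly: their extents are combined
// into a single folded constant shape (the broadcasting call itself is elided
// in this excerpt), and a fresh const_shape replaces them, but only when at
// least two of the operands were constant, per the check above.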
template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern

    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {

        bool isInformationLoosingCast =
            castOp.getType().cast<RankedTensorType>().isDynamicDim(0);
        if (isInformationLoosingCast) {

          return castOp.getSource();

    auto newOperands = llvm::to_vector<8>(
        llvm::map_range(op.getOperands(), canonicalizeOperand));

struct BroadcastConcretizeResultTypePattern

    auto resultTy = op.getType().dyn_cast<RankedTensorType>();
    if (!resultTy || !resultTy.isDynamicDim(0))

    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy = shape.getType().dyn_cast<RankedTensorType>()) {

        if (extentTensorTy.isDynamicDim(0))

        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));

    auto newOp = rewriter.create<BroadcastOp>(

  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
  if (!operands[0] || !operands[1])

  auto lhsShape = llvm::to_vector<6>(

  auto rhsShape = llvm::to_vector<6>(

  resultShape.append(lhsShape.begin(), lhsShape.end());
  resultShape.append(rhsShape.begin(), rhsShape.end());

  interleaveComma(getShape().getValues<int64_t>(), p);

  auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();

    IntegerAttr attr = extent.dyn_cast<IntegerAttr>();

    ints.push_back(attr.getInt());

  result.types.push_back(resultTy);

  patterns.add<TensorCastConstShape>(context);

  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(shape.size())}, b.getIndexType())});

bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,

  if (l.size() != 1 || r.size() != 1)

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())

void CstrBroadcastableOp::getCanonicalizationPatterns(

  patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
               CstrBroadcastableEqOps,
               RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
               RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
  bool nonScalarSeen = false;

      nonScalarSeen = true;

  for (const auto &operand : operands) {

    extents.push_back(llvm::to_vector<6>(

  for (auto shapeValue : getShapes()) {
    extents.emplace_back();

  if (getNumOperands() < 2)
    return emitOpError("required at least 2 input shapes");

  patterns.add<CstrEqEqOps>(context);

  if (llvm::all_of(operands,
                   [&](Attribute a) { return a && a == operands[0]; }))

void ConstSizeOp::getAsmResultNames(

  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());

  return getPassingAttr();

  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();

  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();

  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  if (quotient.isNegative() && !remainder.isNullValue()) {

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
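// The sdivrem result is adjusted so that shape.div rounds toward negative
// infinity: when the quotient is negative and the remainder is nonzero, the
// (elided) body of the if above presumably decrements the quotient by one.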
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});

  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
  bool allSame = true;
  if (!operands.empty() && !operands[0])

  for (Attribute operand : operands.drop_front(1)) {

    allSame = allSame && operand == operands[0];

  patterns.add<SizeToIndexToSizeCanonicalization>(context);

  if (llvm::any_of(operands, [](Attribute a) { return !a; }))

  for (auto attr : operands)
    extents.push_back(attr.cast<IntegerAttr>().getInt());
  Builder builder(getContext());

FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
  auto attr = getMapping()

                  .dyn_cast_or_null<FlatSymbolRefAttr>();

  return lookupSymbol<FuncOp>(attr);

  StringAttr nameAttr;

  DictionaryAttr mappingAttr;

  auto buildFuncType =

         std::string &) { return builder.getFunctionType(argTypes, results); };

      parser, result, false, buildFuncType);
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();

  if (!dim.hasValue())

  if (dim.getValue() >= elements.getNumElements())

  return elements.getValues<Attribute>()[(uint64_t)dim.getValue()];

  Value dim = builder.create<ConstSizeOp>(loc, dimAttr);
  build(builder, result, builder.getType<SizeType>(), shape, dim);

  build(builder, result, builder.getIndexType(), shape, dim);

  inferredReturnTypes.assign({IndexType::get(context)});

bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,

  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);

void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,

  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);

  if (operands.size() < 2) {

  inferredReturnTypes.assign({operands[0].getType()});

  if (l.size() != 1 || r.size() != 1)

  Type lhs = l.front();
  Type rhs = r.front();

  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())

  int64_t rank = shape.getNumElements();
  Builder builder(getContext());

struct RankShapeOfCanonicalizationPattern

    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();

    auto rankedTensorType =
        shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!rankedTensorType)

    int64_t rank = rankedTensorType.getRank();
    if (op.getType().isa<IndexType>()) {

    } else if (op.getType().isa<shape::SizeType>()) {
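// In both branches the rank op is replaced by a constant holding the
// statically known rank of the ranked tensor feeding shape.shape_of,
// materialized to match the result type (index constant vs. shape.const_size).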
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});

  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);

  APInt product(64, 1);

  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());

  if (operands[0].getType().isa<ShapeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});

bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,

  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
  if (getLhs() == getRhs())

  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});

  if (l.size() != 1 || r.size() != 1)

  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())

  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())

  if (getLhs() == getRhs())

  if (operands[0].getType() == operands[1].getType())
    inferredReturnTypes.assign({operands[0].getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});

  if (l.size() != 1 || r.size() != 1)

  if (l.front().isa<ShapeType>() && r.front().isa<ShapeType>())

  if (l.front().isa<SizeType>() && r.front().isa<SizeType>())
  auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>();

  auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>();

  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
  if (operands[0].getType().isa<SizeType>() ||
      operands[1].getType().isa<SizeType>())
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});

  return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
  auto type = getOperand().getType().dyn_cast<ShapedType>();
  if (!type || !type.hasStaticShape())

  Builder builder(getContext());

    if (!op.getArg().getType().isa<ShapedType>())

    if (op.getType().isa<ShapedType>())

    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)

    auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();

    auto argTy = shapeOfOp.getArg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))

  patterns.add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
               ExtractFromShapeOfExtentTensor>(context);
  if (operands[0].getType().isa<ValueShapeType>())
    inferredReturnTypes.assign({ShapeType::get(context)});
  else {
    auto shapedTy = operands[0].getType().cast<ShapedType>();
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamicSize;
    Type indexTy = IndexType::get(context);
    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
    inferredReturnTypes.assign({extentTensorTy});
  }
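// In other words, shape.shape_of yields !shape.shape for !shape.value_shape
// operands and an extent tensor tensor<Nxindex> otherwise, where N is the
// operand's rank or a dynamic size when the operand is unranked.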
  if (l.size() != 1 || r.size() != 1)

  Type lhs = l.front();
  Type rhs = r.front();

  if (!lhs.isa<ShapeType, ShapedType>() || !rhs.isa<ShapeType, ShapedType>())

  if (lhs.isa<ShapeType>() || rhs.isa<ShapeType>())

  patterns.add<IndexToSizeToIndexCanonicalization>(context);

  if (inputs.size() != 1 || outputs.size() != 1)

  return inputs[0].isa<IndexType, SizeType>() && outputs[0].isa<IndexType>();

  auto *parentOp = (*this)->getParentOp();
  auto results = parentOp->getResults();
  auto operands = getOperands();

  if (parentOp->getNumResults() != getNumOperands())
    return emitOpError() << "number of operands does not match number of "
                            "results of its parent";

  for (auto e : llvm::zip(results, operands))
    if (std::get<0>(e).getType() != std::get<1>(e).getType())
      return emitOpError() << "types mismatch between yield op and its parent";

  if (!operands[0] || !operands[1])

  auto shapeVec = llvm::to_vector<6>(

  auto shape = llvm::makeArrayRef(shapeVec);
  auto splitPoint = operands[1].cast<IntegerAttr>().getInt();

  int64_t rank = shape.size();
  if (-rank > splitPoint || splitPoint > rank)

    splitPoint += shape.size();

  Builder builder(operands[0].getContext());
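// Split points may be negative, in which case they count from the back (hence
// the `splitPoint += shape.size()` adjustment); values outside [-rank, rank]
// make the fold bail out.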
  Builder builder(getContext());
  auto shape = llvm::to_vector<6>(

  auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},

  if (inputs.size() != 1 || outputs.size() != 1)

  if (auto inputTensor = inputs[0].dyn_cast<RankedTensorType>()) {
    if (!inputTensor.getElementType().isa<IndexType>() ||
        inputTensor.getRank() != 1)

  } else if (!inputs[0].isa<ShapeType>()) {

    elementType = tensorType.getElementType();

    elementType = SizeType::get(builder.getContext());

  for (Value initVal : initVals) {
    bodyBlock.addArgument(initVal.getType(), initVal.getLoc());
    result.addTypes(initVal.getType());
  auto blockArgsCount = getInitVals().size() + 2;

    return emitOpError() << "ReduceOp body is expected to have "
                         << blockArgsCount << " arguments";

    return emitOpError(
        "argument 0 of ReduceOp body is expected to be of IndexType");

  if (getShape().getType().isa<ShapeType>()) {
    if (!extentTy.isa<SizeType>())
      return emitOpError("argument 1 of ReduceOp body is expected to be of "
                         "SizeType if the ReduceOp operates on a ShapeType");
  } else {
    if (!extentTy.isa<IndexType>())
      return emitOpError(
          "argument 1 of ReduceOp body is expected to be of IndexType if the "
          "ReduceOp operates on an extent tensor");
  }

    if (block.getArgument(type.index() + 2).getType() != type.value().getType())
      return emitOpError() << "type mismatch between argument "
                           << " of ReduceOp body and initial value "

  Type shapeOrExtentTensorType;
  auto initVals = llvm::makeArrayRef(operands).drop_front();
  if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,

  p << '(' << getShape() << ", " << getInitVals()
#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"