26 #include "llvm/ADT/SetOperations.h"
27 #include "llvm/ADT/SmallString.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/raw_ostream.h"
34 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
37 #include "ShapeCanonicalization.inc"
45 auto ranked = llvm::dyn_cast<RankedTensorType>(type);
46 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
52 auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
55 llvm::append_range(shapeValues, type.getShape());
60 llvm::append_range(shapeValues, attr.getValues<int64_t>());
67 return llvm::any_of(operandTypes, [](
Type ty) {
68 return llvm::isa<SizeType, ShapeType, ValueShapeType>(ty);
76 if (!llvm::isa<SizeType>(resultTy))
78 <<
"if at least one of the operands can hold error values then "
79 "the result must be of type `size` to propagate them";
88 if (!llvm::isa<ShapeType>(resultTy))
90 <<
"if at least one of the operands can hold error values then "
91 "the result must be of type `shape` to propagate them";
96 template <
typename... Ty>
98 return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
101 template <
typename... Ty,
typename... ranges>
132 void ShapeDialect::initialize() {
135 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
138 #define GET_TYPEDEF_LIST
139 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
141 addInterfaces<ShapeInlinerInterface>();
145 allowUnknownOperations();
151 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
152 return builder.
create<ub::PoisonOp>(loc, type, poison);
155 return builder.
create<ConstShapeOp>(
156 loc, type, llvm::cast<DenseIntElementsAttr>(value));
157 if (llvm::isa<SizeType>(type))
158 return builder.
create<ConstSizeOp>(loc, type,
159 llvm::cast<IntegerAttr>(value));
160 if (llvm::isa<WitnessType>(type))
161 return builder.
create<ConstWitnessOp>(loc, type,
162 llvm::cast<BoolAttr>(value));
164 return arith::ConstantOp::materialize(builder, value, type, loc);
170 if (attribute.
getName() ==
"shape.lib") {
173 "shape.lib attribute may only be on op implementing SymbolTable");
175 if (
auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.
getValue())) {
178 return op->
emitError(
"shape function library ")
179 << symbolRef <<
" not found";
180 return isa<shape::FunctionLibraryOp>(symbol)
183 << symbolRef <<
" required to be shape function library";
186 if (
auto arr = llvm::dyn_cast<ArrayAttr>(attribute.
getValue())) {
190 for (
auto it : arr) {
191 if (!llvm::isa<SymbolRefAttr>(it))
193 "only SymbolRefAttr allowed in shape.lib attribute array");
195 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
199 << it <<
" does not refer to FunctionLibraryOp";
200 for (
auto mapping : shapeFnLib.getMapping()) {
201 if (!key.insert(mapping.getName()).second) {
202 return op->
emitError(
"only one op to shape mapping allowed, found "
204 << mapping.getName() <<
"`";
211 return op->
emitError(
"only SymbolRefAttr or array of SymbolRefAttrs "
212 "allowed as shape.lib attribute");
225 if (adaptor.getInputs().back())
226 return adaptor.getInputs().back();
262 bool yieldsResults = !getResults().empty();
264 p <<
" " << getWitness();
266 p <<
" -> (" << getResultTypes() <<
")";
281 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
282 if (!witness || !witness.getPassingAttr())
285 AssumingOp::inlineRegionIntoParent(op, rewriter);
290 struct AssumingOpRemoveUnusedResults :
public OpRewritePattern<AssumingOp> {
295 Block *body = op.getBody();
296 auto yieldOp = llvm::cast<AssumingYieldOp>(body->
getTerminator());
300 for (
auto [opResult, yieldOperand] :
301 llvm::zip(op.
getResults(), yieldOp.getOperands())) {
302 if (!opResult.getUses().empty()) {
303 newYieldOperands.push_back(yieldOperand);
308 if (newYieldOperands.size() == yieldOp->getNumOperands())
317 auto newOp = rewriter.
create<AssumingOp>(
318 op.
getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
319 newOp.getDoRegion().takeBody(op.getDoRegion());
323 auto src = newOp.getResults().begin();
325 if (it.getUses().empty())
326 replacementValues.push_back(
nullptr);
328 replacementValues.push_back(*src++);
330 rewriter.
replaceOp(op, replacementValues);
338 patterns.
add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
342 void AssumingOp::getSuccessorRegions(
355 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
358 auto *assumingBlock = op.getBody();
360 auto *blockAfterAssuming =
361 rewriter.
splitBlock(blockBeforeAssuming, initPosition);
364 auto &yieldOp = assumingBlock->
back();
366 rewriter.
replaceOp(op, yieldOp.getOperands());
371 rewriter.
mergeBlocks(assumingBlock, blockBeforeAssuming);
372 rewriter.
mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
375 void AssumingOp::build(
391 for (
Value v : yieldValues)
392 assumingTypes.push_back(v.getType());
401 MLIRContext *context, std::optional<Location> location,
403 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
404 llvm::isa<SizeType>(adaptor.getRhs().getType()))
413 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
416 OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
421 return constFoldBinaryOp<IntegerAttr>(
422 adaptor.getOperands(),
423 [](APInt a,
const APInt &b) { return std::move(a) + b; });
449 for (
Value operand : op.getInputs()) {
450 if (
auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
451 operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
453 operands.push_back(operand);
487 struct AssumingAllOfCstrBroadcastable :
public OpRewritePattern<AssumingAllOp> {
494 for (
Value operand : op.getInputs()) {
497 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
501 operands.insert(broadcastable);
505 if (operands.size() <= 1)
510 for (
auto cstr : operands) {
512 shapes.emplace_back(cstr, std::move(shapesSet));
516 llvm::sort(shapes, [](
auto a,
auto b) {
517 return a.first.getNumOperands() > b.first.getNumOperands();
526 for (
unsigned i = 0; i < shapes.size(); ++i) {
527 auto isSubset = [&](
auto pair) {
528 return llvm::set_is_subset(pair.second, shapes[i].second);
532 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
533 for (
auto *it0 = it; it0 < shapes.end(); ++it0)
534 markedForErase.push_back(it0->first);
535 shapes.erase(it, shapes.end());
539 if (markedForErase.empty())
544 for (
auto &shape : shapes)
545 uniqueConstraints.push_back(shape.first.getResult());
551 for (
auto &op : markedForErase)
559 struct AssumingAllToCstrEqCanonicalization
566 for (
Value w : op.getInputs()) {
567 auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
570 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](
Value s) {
571 return llvm::is_contained(shapes, s);
573 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
575 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
582 template <
typename OpTy>
594 unique.takeVector(), op->
getAttrs());
606 .
add<MergeAssumingAllOps, AssumingAllOneOp,
607 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
608 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
614 for (
int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
622 getOperation()->eraseOperand(idx);
625 if (!llvm::cast<BoolAttr>(a).getValue())
634 if (getNumOperands() == 0)
635 return emitOpError(
"no operands specified");
645 if (getShapes().size() == 1) {
647 if (getShapes().front().getType() != getType())
649 return getShapes().front();
653 if (getShapes().size() > 2)
656 if (!adaptor.getShapes()[0] || !adaptor.getShapes()[1])
658 auto lhsShape = llvm::to_vector<6>(
659 llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[0])
660 .getValues<int64_t>());
661 auto rhsShape = llvm::to_vector<6>(
662 llvm::cast<DenseIntElementsAttr>(adaptor.getShapes()[1])
663 .getValues<int64_t>());
680 template <
typename OpTy>
686 auto isPotentiallyNonEmptyShape = [](
Value shape) {
687 if (
auto extentTensorTy =
688 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
689 if (extentTensorTy.getDimSize(0) == 0)
692 if (
auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
693 if (constShape.getShape().empty())
698 auto newOperands = llvm::to_vector<8>(
699 llvm::make_filter_range(op->
getOperands(), isPotentiallyNonEmptyShape));
712 struct BroadcastForwardSingleOperandPattern
720 Value replacement = op.getShapes().front();
723 if (replacement.
getType() != op.getType()) {
725 if (llvm::isa<ShapeType>(op.getType())) {
726 replacement = rewriter.
create<FromExtentTensorOp>(loc, replacement);
728 assert(!llvm::isa<ShapeType>(op.getType()) &&
729 !llvm::isa<ShapeType>(replacement.
getType()) &&
730 "expect extent tensor cast");
732 rewriter.
create<tensor::CastOp>(loc, op.getType(), replacement);
741 struct BroadcastFoldConstantOperandsPattern
749 for (
Value shape : op.getShapes()) {
750 if (
auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
754 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
755 newFoldedConstantShape)) {
756 foldedConstantShape = newFoldedConstantShape;
760 newShapeOperands.push_back(shape);
768 {
static_cast<int64_t
>(foldedConstantShape.size())},
770 newShapeOperands.push_back(rewriter.
create<ConstShapeOp>(
771 op.
getLoc(), foldedConstantOperandsTy,
779 template <
typename OpTy>
780 struct CanonicalizeCastExtentTensorOperandsPattern
787 bool anyChange =
false;
788 auto canonicalizeOperand = [&](
Value operand) ->
Value {
789 if (
auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
791 bool isInformationLoosingCast =
792 llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
793 if (isInformationLoosingCast) {
795 return castOp.getSource();
800 auto newOperands = llvm::to_vector<8>(
801 llvm::map_range(op.
getOperands(), canonicalizeOperand));
811 struct BroadcastConcretizeResultTypePattern
818 auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
819 if (!resultTy || !resultTy.isDynamicDim(0))
824 for (
Value shape : op.getShapes()) {
825 if (
auto extentTensorTy =
826 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
829 if (extentTensorTy.isDynamicDim(0))
831 maxRank =
std::max(maxRank, extentTensorTy.getDimSize(0));
835 auto newOp = rewriter.
create<BroadcastOp>(
846 patterns.
add<BroadcastConcretizeResultTypePattern,
847 BroadcastFoldConstantOperandsPattern,
848 BroadcastForwardSingleOperandPattern,
849 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
850 RemoveDuplicateOperandsPattern<BroadcastOp>,
851 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
859 if (!adaptor.getLhs() || !adaptor.getRhs())
861 auto lhsShape = llvm::to_vector<6>(
862 llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
863 auto rhsShape = llvm::to_vector<6>(
864 llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
866 resultShape.append(lhsShape.begin(), lhsShape.end());
867 resultShape.append(rhsShape.begin(), rhsShape.end());
880 interleaveComma(
getShape().getValues<int64_t>(), p);
895 auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
900 IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
903 ints.push_back(attr.getInt());
910 result.
types.push_back(resultTy);
// Folds this op to the attribute returned by getShapeAttr(). The FoldAdaptor
// parameter is unnamed and not consulted.
914 OpFoldResult ConstShapeOp::fold(FoldAdaptor) {
return getShapeAttr(); }
918 patterns.
add<TensorCastConstShape>(context);
922 MLIRContext *context, std::optional<Location> location,
925 const Properties prop = adaptor.getProperties();
927 {
static_cast<int64_t
>(prop.shape.size())}, b.getIndexType())});
931 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(
TypeRange l,
933 if (l.size() != 1 || r.size() != 1)
936 Type lhs = l.front();
937 Type rhs = r.front();
939 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
949 void CstrBroadcastableOp::getCanonicalizationPatterns(
954 patterns.
add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
955 CstrBroadcastableEqOps,
956 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
957 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
963 bool nonScalarSeen =
false;
965 if (!a || llvm::cast<DenseIntElementsAttr>(a).
getNumElements() != 0) {
968 nonScalarSeen =
true;
974 OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
981 for (
const auto &operand : adaptor.getShapes()) {
984 extents.push_back(llvm::to_vector<6>(
985 llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
995 for (
auto shapeValue : getShapes()) {
996 extents.emplace_back();
1011 if (getNumOperands() < 2)
1012 return emitOpError(
"required at least 2 input shapes");
1023 patterns.
add<CstrEqEqOps>(context);
1027 if (llvm::all_of(adaptor.getShapes(), [&](
Attribute a) {
1028 return a && a == adaptor.getShapes().front();
// Folds this op to the attribute returned by getValueAttr(). The FoldAdaptor
// parameter is unnamed and not consulted.
1047 OpFoldResult ConstSizeOp::fold(FoldAdaptor) {
return getValueAttr(); }
1049 void ConstSizeOp::getAsmResultNames(
1052 llvm::raw_svector_ostream os(buffer);
1053 os <<
"c" << getValue();
1054 setNameFn(getResult(), os.str());
// Folds this op to the attribute returned by getPassingAttr(). The
// FoldAdaptor parameter is unnamed and not consulted.
1061 OpFoldResult ConstWitnessOp::fold(FoldAdaptor) {
return getPassingAttr(); }
1067 OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
1068 return adaptor.getPred();
1075 std::optional<int64_t> DimOp::getConstantIndex() {
1076 if (
auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1077 return constSizeOp.getValue().getLimitedValue();
1078 if (
auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1079 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1080 return std::nullopt;
1084 Type valType = getValue().getType();
1085 auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1086 if (!valShapedType || !valShapedType.hasRank())
1088 std::optional<int64_t> index = getConstantIndex();
1089 if (!index.has_value())
1091 if (index.value() < 0 || index.value() >= valShapedType.getRank())
1093 auto extent = valShapedType.getDimSize(*index);
1094 if (ShapedType::isDynamic(extent))
1100 MLIRContext *context, std::optional<Location> location,
1102 inferredReturnTypes.assign({adaptor.getIndex().getType()});
1107 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1115 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1118 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1124 APInt quotient, remainder;
1125 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1126 if (quotient.isNegative() && !remainder.isZero()) {
1135 MLIRContext *context, std::optional<Location> location,
1137 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1138 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1147 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1157 bool allSame =
true;
1158 if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1160 for (
Attribute operand : adaptor.getShapes().drop_front()) {
1163 allSame = allSame && operand == adaptor.getShapes().front();
1172 OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1182 patterns.
add<SizeToIndexToSizeCanonicalization>(context);
1189 OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1190 if (llvm::any_of(adaptor.getExtents(), [](
Attribute a) { return !a; }))
1193 for (
auto attr : adaptor.getExtents())
1194 extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1209 FuncOp FunctionLibraryOp::getShapeFunction(
Operation *op) {
1210 auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1214 return lookupSymbol<FuncOp>(attr);
1220 StringAttr nameAttr;
1235 DictionaryAttr mappingAttr;
1247 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(),
"mapping"});
1259 FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
1263 FuncOp::build(builder, state, name, type, attrs);
1266 FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
1271 FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
1274 FuncOp func = create(location, name, type, attrs);
1275 func.setAllArgAttrs(argAttrs);
1282 state.addAttribute(FuncOp::getSymNameAttrName(state.name),
1284 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
1286 state.attributes.append(attrs.begin(), attrs.end());
1289 if (argAttrs.empty())
1291 assert(type.getNumInputs() == argAttrs.size());
1293 builder, state, argAttrs, std::nullopt,
1294 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
1298 auto buildFuncType =
1301 std::string &) {
return builder.
getFunctionType(argTypes, results); };
1304 parser, result,
false,
1305 getFunctionTypeAttrName(result.
name), buildFuncType,
1306 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
1311 p, *
this,
false, getFunctionTypeAttrName(),
1312 getArgAttrsAttrName(), getResAttrsAttrName());
1319 std::optional<int64_t> GetExtentOp::getConstantDim() {
1320 if (
auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1321 return constSizeOp.getValue().getLimitedValue();
1322 if (
auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1323 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1324 return std::nullopt;
1328 auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1331 std::optional<int64_t> dim = getConstantDim();
1332 if (!dim.has_value())
1334 if (dim.value() >= elements.getNumElements())
1336 return elements.getValues<
Attribute>()[(uint64_t)dim.value()];
1343 if (llvm::isa<ShapeType>(shape.
getType())) {
1344 Value dim = builder.
create<ConstSizeOp>(loc, dimAttr);
1345 build(builder, result, builder.
getType<SizeType>(), shape, dim);
1349 build(builder, result, builder.
getIndexType(), shape, dim);
1354 MLIRContext *context, std::optional<Location> location,
1360 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(
TypeRange l,
1363 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1372 void IsBroadcastableOp::getCanonicalizationPatterns(
RewritePatternSet &patterns,
1374 patterns.
add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1377 OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1379 if (adaptor.getShapes().size() < 2) {
1391 MLIRContext *context, std::optional<Location> location,
1393 if (adaptor.getOperands().empty())
1396 auto isShapeType = [](
Type arg) {
1397 if (llvm::isa<ShapeType>(arg))
1403 Type acc = types.front();
1404 for (
auto t : drop_begin(types)) {
1405 Type l = acc, r = t;
1406 if (!llvm::isa<ShapeType, SizeType>(l))
1410 if (llvm::isa<SizeType>(l)) {
1411 if (llvm::isa<SizeType, IndexType>(r))
1415 }
else if (llvm::isa<IndexType>(l)) {
1416 if (llvm::isa<IndexType>(r))
1420 }
else if (llvm::isa<ShapeType>(l)) {
1427 auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
1428 auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
1429 if (ShapedType::isDynamic(rank1))
1431 else if (ShapedType::isDynamic(rank2))
1433 else if (rank1 != rank2)
1439 inferredReturnTypes.assign({acc});
1444 if (l.size() != 1 || r.size() != 1)
1449 Type lhs = l.front();
1450 Type rhs = r.front();
1452 if (!llvm::isa<ShapeType, SizeType>(lhs))
1453 std::swap(lhs, rhs);
1455 if (llvm::isa<SizeType>(lhs))
1456 return llvm::isa<SizeType, IndexType>(rhs);
1457 if (llvm::isa<ShapeType>(lhs))
1458 return llvm::isa<ShapeType, TensorType>(rhs);
1469 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1470 auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1473 int64_t rank = shape.getNumElements();
1493 struct RankShapeOfCanonicalizationPattern
1499 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1502 auto rankedTensorType =
1503 llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1504 if (!rankedTensorType)
1506 int64_t rank = rankedTensorType.getRank();
1507 if (llvm::isa<IndexType>(op.getType())) {
1510 }
else if (llvm::isa<shape::SizeType>(op.getType())) {
1522 patterns.
add<RankShapeOfCanonicalizationPattern>(context);
1526 MLIRContext *context, std::optional<Location> location,
1528 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1537 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1546 OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1554 for (
auto value : llvm::cast<DenseIntElementsAttr>(shape))
1561 MLIRContext *context, std::optional<Location> location,
1562 NumElementsOp::Adaptor adaptor,
1564 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1571 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(
TypeRange l,
1574 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1587 if (getLhs() == getRhs())
1593 MLIRContext *context, std::optional<Location> location,
1595 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1596 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1603 if (l.size() != 1 || r.size() != 1)
1605 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1607 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1618 if (getLhs() == getRhs())
1624 MLIRContext *context, std::optional<Location> location,
1626 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1627 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1634 if (l.size() != 1 || r.size() != 1)
1636 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1638 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1648 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1651 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1654 APInt folded = lhs.getValue() * rhs.getValue();
1660 MLIRContext *context, std::optional<Location> location,
1662 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1663 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1672 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1682 auto type = llvm::dyn_cast<ShapedType>(getOperand().getType());
1683 if (!type || !type.hasStaticShape())
1695 if (!llvm::isa<ShapedType>(op.getArg().getType()))
1697 if (llvm::isa<ShapedType>(op.getType()))
1720 auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1721 if (!ty || ty.getRank() != 1)
1724 auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1729 auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1730 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1741 patterns.
add<ShapeOfCastExtentTensor, ShapeOfWithTensor,
1742 ExtractFromShapeOfExtentTensor>(context);
1746 MLIRContext *context, std::optional<Location> location,
1748 if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1751 auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1753 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1756 inferredReturnTypes.assign({extentTensorTy});
1762 if (l.size() != 1 || r.size() != 1)
1767 Type lhs = l.front();
1768 Type rhs = r.front();
1770 if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1771 !llvm::isa<ShapeType, ShapedType>(rhs))
1774 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1791 OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1801 patterns.
add<IndexToSizeToIndexCanonicalization>(context);
1805 if (inputs.size() != 1 || outputs.size() != 1)
1807 return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1808 llvm::isa<IndexType>(outputs[0]);
1816 auto *parentOp = (*this)->getParentOp();
1817 auto results = parentOp->getResults();
1818 auto operands = getOperands();
1820 if (parentOp->getNumResults() != getNumOperands())
1821 return emitOpError() <<
"number of operands does not match number of "
1822 "results of its parent";
1823 for (
auto e : llvm::zip(results, operands))
1824 if (std::get<0>(e).getType() != std::get<1>(e).getType())
1825 return emitOpError() <<
"types mismatch between yield op and its parent";
1836 if (!adaptor.getOperand() || !adaptor.getIndex())
1838 auto shapeVec = llvm::to_vector<6>(
1839 llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1841 auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1844 int64_t rank = shape.size();
1845 if (-rank > splitPoint || splitPoint > rank)
1848 splitPoint += shape.size();
1849 Builder builder(adaptor.getOperand().getContext());
1859 OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1860 if (!adaptor.getInput())
1863 auto shape = llvm::to_vector<6>(
1864 llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1871 if (inputs.size() != 1 || outputs.size() != 1)
1873 if (
auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1874 if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1875 inputTensor.getRank() != 1)
1877 }
else if (!llvm::isa<ShapeType>(inputs[0])) {
1881 TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1882 return outputTensor && llvm::isa<IndexType>(outputTensor.
getElementType());
1900 if (
auto tensorType = llvm::dyn_cast<TensorType>(shape.
getType()))
1901 elementType = tensorType.getElementType();
1906 for (
Value initVal : initVals) {
1907 bodyBlock.
addArgument(initVal.getType(), initVal.getLoc());
1908 result.
addTypes(initVal.getType());
1917 auto blockArgsCount = getInitVals().size() + 2;
1919 return emitOpError() <<
"ReduceOp body is expected to have "
1920 << blockArgsCount <<
" arguments";
1925 "argument 0 of ReduceOp body is expected to be of IndexType");
1931 if (llvm::isa<ShapeType>(
getShape().getType())) {
1932 if (!llvm::isa<SizeType>(extentTy))
1933 return emitOpError(
"argument 1 of ReduceOp body is expected to be of "
1934 "SizeType if the ReduceOp operates on a ShapeType");
1936 if (!llvm::isa<IndexType>(extentTy))
1938 "argument 1 of ReduceOp body is expected to be of IndexType if the "
1939 "ReduceOp operates on an extent tensor");
1944 return emitOpError() <<
"type mismatch between argument "
1946 <<
" of ReduceOp body and initial value "
1954 Type shapeOrExtentTensorType;
1963 if (parser.
resolveOperand(operands.front(), shapeOrExtentTensorType,
1982 p <<
'(' <<
getShape() <<
", " << getInitVals()
1990 #define GET_OP_CLASSES
1991 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
1993 #define GET_TYPEDEF_CLASSES
1994 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
static bool isErrorPropagationPossible(TypeRange operandTypes)
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
static LogicalResult verifySizeOrIndexOp(Operation *op)
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static int64_t getNumElements(ShapedType type)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
ParseResult parseSymbolName(StringAttr &result)
Parse an @-identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
MLIRContext * getContext() const
NamedAttribute getNamedAttr(StringRef name, Attribute val)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
A trait used to provide symbol table functionalities to a region operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
bool use_empty()
Returns true if this operation has no uses.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
operand_iterator operand_begin()
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
operand_iterator operand_end()
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
operand_range getOperands()
Returns an iterator over the underlying Values.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
void push_back(Block *block)
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
This method replaces the results of the operation with the specified list of values.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
ValueTypeRange< ValueRange > type_range
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
bool isExtentTensorType(Type)
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
This header declares functions that assist transformations in the MemRef dialect.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.