27 #include "llvm/ADT/SetOperations.h"
28 #include "llvm/ADT/SmallString.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 #include "llvm/Support/raw_ostream.h"
35 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
38 #include "ShapeCanonicalization.inc"
46 auto ranked = llvm::dyn_cast<RankedTensorType>(type);
47 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
53 auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
56 llvm::append_range(shapeValues, type.getShape());
61 llvm::append_range(shapeValues, attr.getValues<int64_t>());
68 return llvm::any_of(operandTypes,
69 llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
76 if (!llvm::isa<SizeType>(resultTy))
78 <<
"if at least one of the operands can hold error values then "
79 "the result must be of type `size` to propagate them";
88 if (!llvm::isa<ShapeType>(resultTy))
90 <<
"if at least one of the operands can hold error values then "
91 "the result must be of type `shape` to propagate them";
96 template <
typename... Ty>
98 return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
101 template <
typename... Ty,
typename... ranges>
132 void ShapeDialect::initialize() {
135 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
138 #define GET_TYPEDEF_LIST
139 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
141 addInterfaces<ShapeInlinerInterface>();
145 allowUnknownOperations();
146 declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
153 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
154 return builder.
create<ub::PoisonOp>(loc, type, poison);
157 return builder.
create<ConstShapeOp>(
158 loc, type, llvm::cast<DenseIntElementsAttr>(value));
159 if (llvm::isa<SizeType>(type))
160 return builder.
create<ConstSizeOp>(loc, type,
161 llvm::cast<IntegerAttr>(value));
162 if (llvm::isa<WitnessType>(type))
163 return builder.
create<ConstWitnessOp>(loc, type,
164 llvm::cast<BoolAttr>(value));
166 return arith::ConstantOp::materialize(builder, value, type, loc);
169 LogicalResult ShapeDialect::verifyOperationAttribute(
Operation *op,
172 if (attribute.
getName() ==
"shape.lib") {
175 "shape.lib attribute may only be on op implementing SymbolTable");
177 if (
auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.
getValue())) {
180 return op->
emitError(
"shape function library ")
181 << symbolRef <<
" not found";
182 return isa<shape::FunctionLibraryOp>(symbol)
185 << symbolRef <<
" required to be shape function library";
188 if (
auto arr = llvm::dyn_cast<ArrayAttr>(attribute.
getValue())) {
192 for (
auto it : arr) {
193 if (!llvm::isa<SymbolRefAttr>(it))
195 "only SymbolRefAttr allowed in shape.lib attribute array");
197 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
201 << it <<
" does not refer to FunctionLibraryOp";
202 for (
auto mapping : shapeFnLib.getMapping()) {
203 if (!key.insert(mapping.getName()).second) {
204 return op->
emitError(
"only one op to shape mapping allowed, found "
206 << mapping.getName() <<
"`";
213 return op->
emitError(
"only SymbolRefAttr or array of SymbolRefAttrs "
214 "allowed as shape.lib attribute");
227 if (adaptor.getInputs().back())
228 return adaptor.getInputs().back();
264 bool yieldsResults = !getResults().empty();
266 p <<
" " << getWitness();
268 p <<
" -> (" << getResultTypes() <<
")";
281 LogicalResult matchAndRewrite(AssumingOp op,
283 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
284 if (!witness || !witness.getPassingAttr())
287 AssumingOp::inlineRegionIntoParent(op, rewriter);
292 struct AssumingOpRemoveUnusedResults :
public OpRewritePattern<AssumingOp> {
295 LogicalResult matchAndRewrite(AssumingOp op,
297 Block *body = op.getBody();
298 auto yieldOp = llvm::cast<AssumingYieldOp>(body->
getTerminator());
302 for (
auto [opResult, yieldOperand] :
303 llvm::zip(op.getResults(), yieldOp.getOperands())) {
304 if (!opResult.getUses().empty()) {
305 newYieldOperands.push_back(yieldOperand);
310 if (newYieldOperands.size() == yieldOp->getNumOperands())
319 auto newOp = rewriter.
create<AssumingOp>(
320 op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
321 newOp.getDoRegion().takeBody(op.getDoRegion());
325 auto src = newOp.getResults().begin();
326 for (
auto it : op.getResults()) {
327 if (it.getUses().empty())
328 replacementValues.push_back(
nullptr);
330 replacementValues.push_back(*src++);
332 rewriter.
replaceOp(op, replacementValues);
340 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
344 void AssumingOp::getSuccessorRegions(
357 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
360 auto *assumingBlock = op.getBody();
362 auto *blockAfterAssuming =
363 rewriter.
splitBlock(blockBeforeAssuming, initPosition);
366 auto &yieldOp = assumingBlock->
back();
368 rewriter.
replaceOp(op, yieldOp.getOperands());
373 rewriter.
mergeBlocks(assumingBlock, blockBeforeAssuming);
374 rewriter.
mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
377 void AssumingOp::build(
391 for (
Value v : yieldValues)
392 assumingTypes.push_back(v.getType());
400 LogicalResult mlir::shape::AddOp::inferReturnTypes(
401 MLIRContext *context, std::optional<Location> location,
403 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
404 llvm::isa<SizeType>(adaptor.getRhs().getType()))
413 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
416 OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
421 return constFoldBinaryOp<IntegerAttr>(
422 adaptor.getOperands(),
423 [](APInt a,
const APInt &b) { return std::move(a) + b; });
445 LogicalResult matchAndRewrite(AssumingAllOp op,
449 for (
Value operand : op.getInputs()) {
450 if (
auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
451 operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
453 operands.push_back(operand);
457 if (operands.size() == op.getNumOperands())
487 struct AssumingAllOfCstrBroadcastable :
public OpRewritePattern<AssumingAllOp> {
490 LogicalResult matchAndRewrite(AssumingAllOp op,
494 for (
Value operand : op.getInputs()) {
497 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
501 operands.insert(broadcastable);
505 if (operands.size() <= 1)
510 for (
auto cstr : operands) {
512 shapes.emplace_back(cstr, std::move(shapesSet));
516 llvm::sort(shapes, [](
auto a,
auto b) {
517 return a.first.getNumOperands() > b.first.getNumOperands();
526 for (
unsigned i = 0; i < shapes.size(); ++i) {
527 auto isSubset = [&](
auto pair) {
528 return llvm::set_is_subset(pair.second, shapes[i].second);
532 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
533 for (
auto *it0 = it; it0 < shapes.end(); ++it0)
534 markedForErase.push_back(it0->first);
535 shapes.erase(it, shapes.end());
539 if (markedForErase.empty())
544 for (
auto &shape : shapes)
545 uniqueConstraints.push_back(shape.first.getResult());
551 for (
auto &op : markedForErase)
559 struct AssumingAllToCstrEqCanonicalization
563 LogicalResult matchAndRewrite(AssumingAllOp op,
566 for (
Value w : op.getInputs()) {
567 auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
570 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](
Value s) {
571 return llvm::is_contained(shapes, s);
573 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
575 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
582 template <
typename OpTy>
586 LogicalResult matchAndRewrite(OpTy op,
592 if (unique.size() < op.getNumOperands()) {
594 unique.takeVector(), op->getAttrs());
606 .add<MergeAssumingAllOps, AssumingAllOneOp,
607 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
608 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
614 for (
int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
622 getOperation()->eraseOperand(idx);
625 if (!llvm::cast<BoolAttr>(a).getValue())
634 if (getNumOperands() == 0)
635 return emitOpError(
"no operands specified");
645 if (getShapes().size() == 1) {
649 return getShapes().front();
652 if (!adaptor.getShapes().front())
656 llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
657 .getValues<int64_t>());
659 for (
auto next : adaptor.getShapes().drop_front()) {
662 auto nextShape = llvm::to_vector<6>(
663 llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
672 std::copy(tmpShape.begin(), tmpShape.end(),
673 std::back_inserter(resultShape));
685 template <
typename OpTy>
689 LogicalResult matchAndRewrite(OpTy op,
691 auto isPotentiallyNonEmptyShape = [](
Value shape) {
692 if (
auto extentTensorTy =
693 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
694 if (extentTensorTy.getDimSize(0) == 0)
697 if (
auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
698 if (constShape.getShape().empty())
703 auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
704 isPotentiallyNonEmptyShape);
708 if (newOperands.empty()) {
715 if (newOperands.size() < op.getNumOperands()) {
725 struct BroadcastForwardSingleOperandPattern
729 LogicalResult matchAndRewrite(BroadcastOp op,
731 if (op.getNumOperands() != 1)
733 Value replacement = op.getShapes().front();
736 if (replacement.
getType() != op.getType()) {
738 if (llvm::isa<ShapeType>(op.getType())) {
739 replacement = rewriter.
create<FromExtentTensorOp>(loc, replacement);
741 assert(!llvm::isa<ShapeType>(op.getType()) &&
742 !llvm::isa<ShapeType>(replacement.
getType()) &&
743 "expect extent tensor cast");
745 rewriter.
create<tensor::CastOp>(loc, op.getType(), replacement);
754 struct BroadcastFoldConstantOperandsPattern
758 LogicalResult matchAndRewrite(BroadcastOp op,
762 for (
Value shape : op.getShapes()) {
763 if (
auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
767 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
768 newFoldedConstantShape)) {
769 foldedConstantShape = newFoldedConstantShape;
773 newShapeOperands.push_back(shape);
777 if (op.getNumOperands() - newShapeOperands.size() < 2)
781 {
static_cast<int64_t
>(foldedConstantShape.size())},
783 newShapeOperands.push_back(rewriter.
create<ConstShapeOp>(
784 op.getLoc(), foldedConstantOperandsTy,
792 template <
typename OpTy>
793 struct CanonicalizeCastExtentTensorOperandsPattern
797 LogicalResult matchAndRewrite(OpTy op,
800 bool anyChange =
false;
801 auto canonicalizeOperand = [&](
Value operand) ->
Value {
802 if (
auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
804 bool isInformationLoosingCast =
805 llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
806 if (isInformationLoosingCast) {
808 return castOp.getSource();
813 auto newOperands = llvm::to_vector<8>(
814 llvm::map_range(op.getOperands(), canonicalizeOperand));
824 struct BroadcastConcretizeResultTypePattern
828 LogicalResult matchAndRewrite(BroadcastOp op,
831 auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
832 if (!resultTy || !resultTy.isDynamicDim(0))
837 for (
Value shape : op.getShapes()) {
838 if (
auto extentTensorTy =
839 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
842 if (extentTensorTy.isDynamicDim(0))
844 maxRank =
std::max(maxRank, extentTensorTy.getDimSize(0));
848 auto newOp = rewriter.
create<BroadcastOp>(
859 patterns.add<BroadcastConcretizeResultTypePattern,
860 BroadcastFoldConstantOperandsPattern,
861 BroadcastForwardSingleOperandPattern,
862 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
863 RemoveDuplicateOperandsPattern<BroadcastOp>,
864 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
872 if (!adaptor.getLhs() || !adaptor.getRhs())
874 auto lhsShape = llvm::to_vector<6>(
875 llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
876 auto rhsShape = llvm::to_vector<6>(
877 llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
879 resultShape.append(lhsShape.begin(), lhsShape.end());
880 resultShape.append(rhsShape.begin(), rhsShape.end());
893 interleaveComma(
getShape().getValues<int64_t>(), p);
908 auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
913 IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
916 ints.push_back(attr.getInt());
923 result.
types.push_back(resultTy);
927 OpFoldResult ConstShapeOp::fold(FoldAdaptor) {
return getShapeAttr(); }
931 patterns.add<TensorCastConstShape>(context);
934 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
935 MLIRContext *context, std::optional<Location> location,
938 const Properties prop = adaptor.getProperties();
940 {
static_cast<int64_t
>(prop.shape.size())}, b.getIndexType())});
944 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(
TypeRange l,
946 if (l.size() != 1 || r.size() != 1)
949 Type lhs = l.front();
950 Type rhs = r.front();
952 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
962 void CstrBroadcastableOp::getCanonicalizationPatterns(
967 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
968 CstrBroadcastableEqOps,
969 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
970 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
976 bool nonScalarSeen =
false;
978 if (!a || llvm::cast<DenseIntElementsAttr>(a).
getNumElements() != 0) {
981 nonScalarSeen =
true;
987 OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
994 for (
const auto &operand : adaptor.getShapes()) {
997 extents.push_back(llvm::to_vector<6>(
998 llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
1008 for (
auto shapeValue : getShapes()) {
1009 extents.emplace_back();
1010 if (failed(
getShapeVec(shapeValue, extents.back())))
1024 if (getNumOperands() < 2)
1025 return emitOpError(
"required at least 2 input shapes");
1036 patterns.add<CstrEqEqOps>(context);
1040 if (llvm::all_of(adaptor.getShapes(), [&](
Attribute a) {
1041 return a && a == adaptor.getShapes().front();
1060 OpFoldResult ConstSizeOp::fold(FoldAdaptor) {
return getValueAttr(); }
1062 void ConstSizeOp::getAsmResultNames(
1065 llvm::raw_svector_ostream os(buffer);
1066 os <<
"c" << getValue();
1067 setNameFn(getResult(), os.str());
1074 OpFoldResult ConstWitnessOp::fold(FoldAdaptor) {
return getPassingAttr(); }
1080 OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
1081 return adaptor.getPred();
1088 std::optional<int64_t> DimOp::getConstantIndex() {
1089 if (
auto constSizeOp =
getIndex().getDefiningOp<ConstSizeOp>())
1090 return constSizeOp.getValue().getLimitedValue();
1091 if (
auto constantOp =
getIndex().getDefiningOp<arith::ConstantOp>())
1092 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1093 return std::nullopt;
1097 Type valType = getValue().getType();
1098 auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1099 if (!valShapedType || !valShapedType.hasRank())
1101 std::optional<int64_t> index = getConstantIndex();
1102 if (!index.has_value())
1104 if (index.value() < 0 || index.value() >= valShapedType.getRank())
1106 auto extent = valShapedType.getDimSize(*index);
1107 if (ShapedType::isDynamic(extent))
1112 LogicalResult mlir::shape::DimOp::inferReturnTypes(
1113 MLIRContext *context, std::optional<Location> location,
1115 inferredReturnTypes.assign({adaptor.getIndex().
getType()});
1120 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1128 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1131 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1132 if (!rhs || rhs.getValue().isZero())
1137 APInt quotient, remainder;
1138 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1139 if (quotient.isNegative() && !remainder.isZero()) {
1147 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1148 MLIRContext *context, std::optional<Location> location,
1150 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1151 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1160 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1170 bool allSame =
true;
1171 if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1173 for (
Attribute operand : adaptor.getShapes().drop_front()) {
1176 allSame = allSame && operand == adaptor.getShapes().front();
1185 OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1195 patterns.add<SizeToIndexToSizeCanonicalization>(context);
1202 OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1203 if (llvm::any_of(adaptor.getExtents(), [](
Attribute a) { return !a; }))
1206 for (
auto attr : adaptor.getExtents())
1207 extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1222 FuncOp FunctionLibraryOp::getShapeFunction(
Operation *op) {
1223 auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1227 return lookupSymbol<FuncOp>(attr);
1233 StringAttr nameAttr;
1248 DictionaryAttr mappingAttr;
1260 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(),
"mapping"});
1272 FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
1276 FuncOp::build(builder, state, name, type, attrs);
1279 FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
1284 FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
1287 FuncOp func = create(location, name, type, attrs);
1288 func.setAllArgAttrs(argAttrs);
1295 state.addAttribute(FuncOp::getSymNameAttrName(state.name),
1297 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
1299 state.attributes.append(attrs.begin(), attrs.end());
1302 if (argAttrs.empty())
1304 assert(type.getNumInputs() == argAttrs.size());
1306 builder, state, argAttrs, std::nullopt,
1307 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
1311 auto buildFuncType =
1314 std::string &) {
return builder.
getFunctionType(argTypes, results); };
1317 parser, result,
false,
1318 getFunctionTypeAttrName(result.
name), buildFuncType,
1319 getArgAttrsAttrName(result.
name), getResAttrsAttrName(result.
name));
1324 p, *
this,
false, getFunctionTypeAttrName(),
1325 getArgAttrsAttrName(), getResAttrsAttrName());
1332 std::optional<int64_t> GetExtentOp::getConstantDim() {
1333 if (
auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1334 return constSizeOp.getValue().getLimitedValue();
1335 if (
auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1336 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1337 return std::nullopt;
1341 auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1344 std::optional<int64_t> dim = getConstantDim();
1345 if (!dim.has_value())
1347 if (dim.value() >= elements.getNumElements())
1349 return elements.getValues<
Attribute>()[(uint64_t)dim.value()];
1356 if (llvm::isa<ShapeType>(shape.
getType())) {
1357 Value dim = builder.
create<ConstSizeOp>(loc, dimAttr);
1358 build(builder, result, builder.
getType<SizeType>(), shape, dim);
1362 build(builder, result, builder.
getIndexType(), shape, dim);
1366 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1367 MLIRContext *context, std::optional<Location> location,
1373 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(
TypeRange l,
1376 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1387 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1390 OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1392 if (adaptor.getShapes().size() < 2) {
1403 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1404 MLIRContext *context, std::optional<Location> location,
1406 if (adaptor.getOperands().empty())
1409 auto isShapeType = [](
Type arg) {
1410 if (llvm::isa<ShapeType>(arg))
1416 Type acc = types.front();
1417 for (
auto t : drop_begin(types)) {
1418 Type l = acc, r = t;
1419 if (!llvm::isa<ShapeType, SizeType>(l))
1423 if (llvm::isa<SizeType>(l)) {
1424 if (llvm::isa<SizeType, IndexType>(r))
1428 }
else if (llvm::isa<IndexType>(l)) {
1429 if (llvm::isa<IndexType>(r))
1433 }
else if (llvm::isa<ShapeType>(l)) {
1440 auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
1441 auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
1442 if (ShapedType::isDynamic(rank1))
1444 else if (ShapedType::isDynamic(rank2))
1446 else if (rank1 != rank2)
1452 inferredReturnTypes.assign({acc});
1457 if (l.size() != 1 || r.size() != 1)
1462 Type lhs = l.front();
1463 Type rhs = r.front();
1465 if (!llvm::isa<ShapeType, SizeType>(lhs))
1466 std::swap(lhs, rhs);
1468 if (llvm::isa<SizeType>(lhs))
1469 return llvm::isa<SizeType, IndexType>(rhs);
1470 if (llvm::isa<ShapeType>(lhs))
1471 return llvm::isa<ShapeType, TensorType>(rhs);
1482 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1483 auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1486 int64_t rank = shape.getNumElements();
1506 struct RankShapeOfCanonicalizationPattern
1510 LogicalResult matchAndRewrite(shape::RankOp op,
1512 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1515 auto rankedTensorType =
1516 llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1517 if (!rankedTensorType)
1519 int64_t rank = rankedTensorType.getRank();
1520 if (llvm::isa<IndexType>(op.getType())) {
1523 }
else if (llvm::isa<shape::SizeType>(op.getType())) {
1535 patterns.add<RankShapeOfCanonicalizationPattern>(context);
1538 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1539 MLIRContext *context, std::optional<Location> location,
1541 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1550 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1559 OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1567 for (
auto value : llvm::cast<DenseIntElementsAttr>(shape))
1573 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1574 MLIRContext *context, std::optional<Location> location,
1575 NumElementsOp::Adaptor adaptor,
1577 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1584 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(
TypeRange l,
1587 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1600 if (getLhs() == getRhs())
1605 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1606 MLIRContext *context, std::optional<Location> location,
1608 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1609 inferredReturnTypes.assign({adaptor.getLhs().
getType()});
1616 if (l.size() != 1 || r.size() != 1)
1618 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1620 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1631 if (getLhs() == getRhs())
1636 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1637 MLIRContext *context, std::optional<Location> location,
1639 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1640 inferredReturnTypes.assign({adaptor.getLhs().
getType()});
1647 if (l.size() != 1 || r.size() != 1)
1649 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1651 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1661 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1664 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1667 APInt folded = lhs.getValue() * rhs.getValue();
1672 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1673 MLIRContext *context, std::optional<Location> location,
1675 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1676 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1685 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1696 struct ShapeOfOpToConstShapeOp :
public OpRewritePattern<shape::ShapeOfOp> {
1699 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1701 auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1702 if (!type || !type.hasStaticShape())
1707 .
create<ConstShapeOp>(loc,
1710 if (constShape.
getType() != op.getResult().getType())
1711 constShape = rewriter.
create<tensor::CastOp>(
1712 loc, op.getResult().getType(), constShape);
1731 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1733 auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1734 if (!tensorReshapeOp)
1736 if (!isa<TensorType>(op.getType()))
1747 Value shape = tensorReshapeOp.getShape();
1749 auto opTensorTy = cast<RankedTensorType>(op.getType());
1750 auto shapeTensorTy = cast<RankedTensorType>(shape.
getType());
1752 if (opTensorTy != shapeTensorTy) {
1753 if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
1754 shape = rewriter.
create<tensor::CastOp>(op.getLoc(), opTensorTy, shape);
1757 rewriter.
create<arith::IndexCastOp>(op.getLoc(), opTensorTy, shape);
1777 LogicalResult matchAndRewrite(tensor::CastOp op,
1779 auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1780 if (!ty || ty.getRank() != 1)
1783 auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1788 auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1789 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1800 patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1801 ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1805 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1806 MLIRContext *context, std::optional<Location> location,
1808 if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1811 auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1813 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1816 inferredReturnTypes.assign({extentTensorTy});
1822 if (l.size() != 1 || r.size() != 1)
1827 Type lhs = l.front();
1828 Type rhs = r.front();
1830 if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1831 !llvm::isa<ShapeType, ShapedType>(rhs))
1834 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1851 OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1861 patterns.add<IndexToSizeToIndexCanonicalization>(context);
1865 if (inputs.size() != 1 || outputs.size() != 1)
1867 return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1868 llvm::isa<IndexType>(outputs[0]);
1876 auto *parentOp = (*this)->getParentOp();
1877 auto results = parentOp->getResults();
1878 auto operands = getOperands();
1880 if (parentOp->getNumResults() != getNumOperands())
1881 return emitOpError() <<
"number of operands does not match number of "
1882 "results of its parent";
1883 for (
auto e : llvm::zip(results, operands))
1885 return emitOpError() <<
"types mismatch between yield op and its parent";
1894 LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1896 if (!adaptor.getOperand() || !adaptor.getIndex())
1898 auto shapeVec = llvm::to_vector<6>(
1899 llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1901 auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1904 int64_t rank = shape.size();
1905 if (-rank > splitPoint || splitPoint > rank)
1908 splitPoint += shape.size();
1909 Builder builder(adaptor.getOperand().getContext());
1919 OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1920 if (!adaptor.getInput())
1923 auto shape = llvm::to_vector<6>(
1924 llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1931 if (inputs.size() != 1 || outputs.size() != 1)
1933 if (
auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1934 if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1935 inputTensor.getRank() != 1)
1937 }
else if (!llvm::isa<ShapeType>(inputs[0])) {
1941 TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1942 return outputTensor && llvm::isa<IndexType>(outputTensor.
getElementType());
1960 if (
auto tensorType = llvm::dyn_cast<TensorType>(shape.
getType()))
1961 elementType = tensorType.getElementType();
1966 for (
Value initVal : initVals) {
1967 bodyBlock->
addArgument(initVal.getType(), initVal.getLoc());
1968 result.
addTypes(initVal.getType());
1977 auto blockArgsCount = getInitVals().size() + 2;
1979 return emitOpError() <<
"ReduceOp body is expected to have "
1980 << blockArgsCount <<
" arguments";
1985 "argument 0 of ReduceOp body is expected to be of IndexType");
1992 if (!llvm::isa<SizeType>(extentTy))
1993 return emitOpError(
"argument 1 of ReduceOp body is expected to be of "
1994 "SizeType if the ReduceOp operates on a ShapeType");
1996 if (!llvm::isa<IndexType>(extentTy))
1998 "argument 1 of ReduceOp body is expected to be of IndexType if the "
1999 "ReduceOp operates on an extent tensor");
2004 return emitOpError() <<
"type mismatch between argument "
2006 <<
" of ReduceOp body and initial value "
2014 Type shapeOrExtentTensorType;
2023 if (parser.
resolveOperand(operands.front(), shapeOrExtentTensorType,
2042 p <<
'(' <<
getShape() <<
", " << getInitVals()
2050 #define GET_OP_CLASSES
2051 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2053 #define GET_TYPEDEF_CLASSES
2054 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static bool isErrorPropagationPossible(TypeRange operandTypes)
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
static LogicalResult verifySizeOrIndexOp(Operation *op)
static Operation * materializeConstant(Dialect *dialect, OpBuilder &builder, Attribute value, Type type, Location loc)
A utility function used to materialize a constant for a given attribute and type.
static int64_t product(ArrayRef< int64_t > vals)
static MLIRContext * getContext(OpFoldResult val)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static void print(spirv::VerCapExtAttr triple, DialectAsmPrinter &printer)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
MLIRContext * getContext() const
NamedAttribute getNamedAttr(StringRef name, Attribute val)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void push_back(NamedAttribute newAttribute)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes=std::nullopt, ArrayRef< Location > locs=std::nullopt)
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
A trait used to provide symbol table functionalities to a region operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void mergeBlocks(Block *source, Block *dest, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into the end of block 'dest'.
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
ValueTypeRange< ValueRange > type_range
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 >> shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
QueryRef parse(llvm::StringRef line, const QuerySession &qs)
bool isExtentTensorType(Type)
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
SmallVector< Value, 4 > operands
void addOperands(ValueRange newOperands)
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
void addTypes(ArrayRef< Type > newTypes)
SmallVector< std::unique_ptr< Region >, 1 > regions
Regions that the op will hold.
SmallVector< Type, 4 > types
Types of the results of this operation.
Region * addRegion()
Create a region that should be attached to the operation.