27 #include "llvm/ADT/SetOperations.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/raw_ostream.h"
34 #include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
37 #include "ShapeCanonicalization.inc"
45 auto ranked = llvm::dyn_cast<RankedTensorType>(type);
46 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
52 auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
55 llvm::append_range(shapeValues, type.getShape());
60 llvm::append_range(shapeValues, attr.getValues<int64_t>());
67 return llvm::any_of(operandTypes,
68 llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
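// Used by the verifiers below: when any operand type can carry an error
// (size, shape, or value_shape), the result type must be error-carrying too.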
75 if (!llvm::isa<SizeType>(resultTy))
77 << "if at least one of the operands can hold error values then "
78 "the result must be of type `size` to propagate them";
87 if (!llvm::isa<ShapeType>(resultTy))
89 << "if at least one of the operands can hold error values then "
90 "the result must be of type `shape` to propagate them";
95 template <typename... Ty>
97 return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
100 template <typename... Ty, typename... ranges>
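// Variadic helper for the isCompatibleReturnTypes hooks below: each given
// type range must contain exactly one type, and it must be one of Ty...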
131 void ShapeDialect::initialize() {
134 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
137 #define GET_TYPEDEF_LIST
138 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
140 addInterfaces<ShapeInlinerInterface>();
144 allowUnknownOperations();
145 declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
152 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
153 return ub::PoisonOp::create(builder, loc, type, poison);
156 return ConstShapeOp::create(builder, loc, type,
157 llvm::cast<DenseIntElementsAttr>(value));
158 if (llvm::isa<SizeType>(type))
159 return ConstSizeOp::create(builder, loc, type,
160 llvm::cast<IntegerAttr>(value));
161 if (llvm::isa<WitnessType>(type))
162 return ConstWitnessOp::create(builder, loc, type,
163 llvm::cast<BoolAttr>(value));
165 return arith::ConstantOp::materialize(builder, value, type, loc);
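// Constants that are not shape-dialect attributes (e.g. plain index values)
// are handed off to the arith dialect's constant materializer.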
168 LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
171 if (attribute.getName() == "shape.lib") {
174 "shape.lib attribute may only be on op implementing SymbolTable");
176 if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
179 return op->emitError("shape function library ")
180 << symbolRef << " not found";
181 return isa<shape::FunctionLibraryOp>(symbol)
184 << symbolRef << " required to be shape function library";
187 if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
191 for (auto it : arr) {
192 if (!llvm::isa<SymbolRefAttr>(it))
194 "only SymbolRefAttr allowed in shape.lib attribute array");
196 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
200 << it << " does not refer to FunctionLibraryOp";
201 for (auto mapping : shapeFnLib.getMapping()) {
202 if (!key.insert(mapping.getName()).second) {
203 return op->emitError("only one op to shape mapping allowed, found "
205 << mapping.getName() << "`";
212 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
213 "allowed as shape.lib attribute");
226 if (adaptor.getInputs().back())
227 return adaptor.getInputs().back();
263 bool yieldsResults = !getResults().empty();
265 p << " " << getWitness();
267 p << " -> (" << getResultTypes() << ")";
280 LogicalResult matchAndRewrite(AssumingOp op,
282 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
283 if (!witness || !witness.getPassingAttr())
286 AssumingOp::inlineRegionIntoParent(op, rewriter);
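// If the witness is a constant `true`, the assuming region carries no
// constraint and can be inlined into the parent block.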
291 struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
294 LogicalResult matchAndRewrite(AssumingOp op,
296 Block *body = op.getBody();
297 auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());
301 for (auto [opResult, yieldOperand] :
302 llvm::zip(op.getResults(), yieldOp.getOperands())) {
303 if (!opResult.getUses().empty()) {
304 newYieldOperands.push_back(yieldOperand);
309 if (newYieldOperands.size() == yieldOp->getNumOperands())
318 auto newOp = AssumingOp::create(
319 rewriter, op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
320 newOp.getDoRegion().takeBody(op.getDoRegion());
324 auto src = newOp.getResults().begin();
325 for (auto it : op.getResults()) {
326 if (it.getUses().empty())
327 replacementValues.push_back(nullptr);
329 replacementValues.push_back(*src++);
331 rewriter.replaceOp(op, replacementValues);
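// Unused results are mapped to a null Value so that the remaining results of
// the slimmed-down assuming op still line up with the original result list.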
339 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
343 void AssumingOp::getSuccessorRegions(
356 void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
359 auto *assumingBlock = op.getBody();
361 auto *blockAfterAssuming =
362 rewriter.splitBlock(blockBeforeAssuming, initPosition);
365 auto &yieldOp = assumingBlock->back();
367 rewriter.replaceOp(op, yieldOp.getOperands());
372 rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
373 rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
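// Inlining works by splitting the parent block at the assuming op, merging
// the assuming body into the first half, and then re-attaching the tail block.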
376 void AssumingOp::build(
387 AssumingYieldOp::create(builder, result.location, yieldValues);
390 for (Value v : yieldValues)
391 assumingTypes.push_back(v.getType());
399 LogicalResult mlir::shape::AddOp::inferReturnTypes(
400 MLIRContext *context, std::optional<Location> location,
402 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
403 llvm::isa<SizeType>(adaptor.getRhs().getType()))
412 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
415 OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
420 return constFoldBinaryOp<IntegerAttr>(
421 adaptor.getOperands(),
422 [](APInt a, const APInt &b) { return std::move(a) + b; });
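// The constant case relies on constFoldBinaryOp, i.e. shape.add folds here
// only when both operands are integer attributes.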
444 LogicalResult matchAndRewrite(AssumingAllOp op,
448 for (Value operand : op.getInputs()) {
449 if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
450 operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
452 operands.push_back(operand);
456 if (operands.size() == op.getNumOperands())
486 struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
489 LogicalResult matchAndRewrite(AssumingAllOp op,
493 for (Value operand : op.getInputs()) {
496 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
500 operands.insert(broadcastable);
504 if (operands.size() <= 1)
509 for (auto cstr : operands) {
511 shapes.emplace_back(cstr, std::move(shapesSet));
515 llvm::sort(shapes, [](auto a, auto b) {
516 return a.first.getNumOperands() > b.first.getNumOperands();
525 for (unsigned i = 0; i < shapes.size(); ++i) {
526 auto isSubset = [&](auto pair) {
527 return llvm::set_is_subset(pair.second, shapes[i].second);
531 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
532 for (auto *it0 = it; it0 < shapes.end(); ++it0)
533 markedForErase.push_back(it0->first);
534 shapes.erase(it, shapes.end());
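// Constraints whose operand shapes are a subset of another constraint's
// operands are redundant; they are removed here and marked for later erasure.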
538 if (markedForErase.empty())
543 for (auto &shape : shapes)
544 uniqueConstraints.push_back(shape.first.getResult());
550 for (auto &op : markedForErase)
558 struct AssumingAllToCstrEqCanonicalization
562 LogicalResult matchAndRewrite(AssumingAllOp op,
565 for (Value w : op.getInputs()) {
566 auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
569 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
570 return llvm::is_contained(shapes, s);
572 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
574 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
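// Each cstr_eq must share at least one shape with those already collected;
// otherwise collapsing them into one cstr_eq would equate unrelated shapes.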
581 template <typename OpTy>
585 LogicalResult matchAndRewrite(OpTy op,
591 if (unique.size() < op.getNumOperands()) {
593 unique.takeVector(), op->getAttrs());
605 .add<MergeAssumingAllOps, AssumingAllOneOp,
606 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
607 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
613 for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
621 getOperation()->eraseOperand(idx);
624 if (!llvm::cast<BoolAttr>(a).getValue())
633 if (getNumOperands() == 0)
634 return emitOpError("no operands specified");
644 if (getShapes().size() == 1) {
648 return getShapes().front();
651 if (!adaptor.getShapes().front())
655 llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
656 .getValues<int64_t>());
658 for (auto next : adaptor.getShapes().drop_front()) {
661 auto nextShape = llvm::to_vector<6>(
662 llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
671 std::copy(tmpShape.begin(), tmpShape.end(),
672 std::back_inserter(resultShape));
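// Folding only proceeds while every operand is a constant extent tensor; each
// constant shape is folded into the running broadcasted result shape.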
684 template <typename OpTy>
688 LogicalResult matchAndRewrite(OpTy op,
690 auto isPotentiallyNonEmptyShape = [](Value shape) {
691 if (auto extentTensorTy =
692 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
693 if (extentTensorTy.getDimSize(0) == 0)
696 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
697 if (constShape.getShape().empty())
702 auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
703 isPotentiallyNonEmptyShape);
707 if (newOperands.empty()) {
714 if (newOperands.size() < op.getNumOperands()) {
724 struct BroadcastForwardSingleOperandPattern
728 LogicalResult matchAndRewrite(BroadcastOp op,
730 if (op.getNumOperands() != 1)
732 Value replacement = op.getShapes().front();
735 if (replacement.getType() != op.getType()) {
736 auto loc = op.getLoc();
737 if (llvm::isa<ShapeType>(op.getType())) {
738 replacement = FromExtentTensorOp::create(rewriter, loc, replacement);
740 assert(!llvm::isa<ShapeType>(op.getType()) &&
741 !llvm::isa<ShapeType>(replacement.getType()) &&
742 "expect extent tensor cast");
744 tensor::CastOp::create(rewriter, loc, op.getType(), replacement);
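// A broadcast of a single shape is the identity; at most a type adjustment
// (shape.from_extent_tensor or tensor.cast) is needed on the way out.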
753 struct BroadcastFoldConstantOperandsPattern
757 LogicalResult matchAndRewrite(BroadcastOp op,
761 for (Value shape : op.getShapes()) {
762 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
766 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
767 newFoldedConstantShape)) {
768 foldedConstantShape = newFoldedConstantShape;
772 newShapeOperands.push_back(shape);
776 if (op.getNumOperands() - newShapeOperands.size() < 2)
780 {static_cast<int64_t>(foldedConstantShape.size())},
782 newShapeOperands.push_back(
783 ConstShapeOp::create(rewriter, op.getLoc(), foldedConstantOperandsTy,
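// All constant operands are pre-broadcast into a single ConstShapeOp, leaving
// the broadcast op to handle only the remaining dynamic shapes.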
791 template <typename OpTy>
792 struct CanonicalizeCastExtentTensorOperandsPattern
796 LogicalResult matchAndRewrite(OpTy op,
799 bool anyChange = false;
800 auto canonicalizeOperand = [&](Value operand) -> Value {
801 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
803 bool isInformationLoosingCast =
804 llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
805 if (isInformationLoosingCast) {
807 return castOp.getSource();
812 auto newOperands = llvm::to_vector<8>(
813 llvm::map_range(op.getOperands(), canonicalizeOperand));
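// tensor.cast operands that only erase the static extent-tensor size are
// bypassed so the op keeps seeing the more precise source type.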
823 struct BroadcastConcretizeResultTypePattern
827 LogicalResult matchAndRewrite(BroadcastOp op,
830 auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
831 if (!resultTy || !resultTy.isDynamicDim(0))
836 for (Value shape : op.getShapes()) {
837 if (auto extentTensorTy =
838 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
841 if (extentTensorTy.isDynamicDim(0))
843 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
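// The broadcasted rank equals the maximum operand rank, so a static result
// type can be used once all operand ranks are statically known.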
847 auto newOp = BroadcastOp::create(rewriter, op.getLoc(),
858 patterns.add<BroadcastConcretizeResultTypePattern,
859 BroadcastFoldConstantOperandsPattern,
860 BroadcastForwardSingleOperandPattern,
861 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
862 RemoveDuplicateOperandsPattern<BroadcastOp>,
863 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
871 if (!adaptor.getLhs() || !adaptor.getRhs())
873 auto lhsShape = llvm::to_vector<6>(
874 llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
875 auto rhsShape = llvm::to_vector<6>(
876 llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
878 resultShape.append(lhsShape.begin(), lhsShape.end());
879 resultShape.append(rhsShape.begin(), rhsShape.end());
892 interleaveComma(getShape().getValues<int64_t>(), p);
907 auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
912 IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
915 ints.push_back(attr.getInt());
922 result.types.push_back(resultTy);
926 OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
930 patterns.add<TensorCastConstShape>(context);
933 LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
934 MLIRContext *context, std::optional<Location> location,
937 const Properties prop = adaptor.getProperties();
939 {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
943 bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
945 if (l.size() != 1 || r.size() != 1)
948 Type lhs = l.front();
949 Type rhs = r.front();
951 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
961 void CstrBroadcastableOp::getCanonicalizationPatterns(
966 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
967 CstrBroadcastableEqOps,
968 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
969 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
975 bool nonScalarSeen = false;
977 if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
980 nonScalarSeen = true;
986 OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
993 for (const auto &operand : adaptor.getShapes()) {
996 extents.push_back(llvm::to_vector<6>(
997 llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
1007 for (auto shapeValue : getShapes()) {
1008 extents.emplace_back();
1009 if (failed(getShapeVec(shapeValue, extents.back())))
1023 if (getNumOperands() < 2)
1024 return emitOpError("required at least 2 input shapes");
1035 patterns.add<CstrEqEqOps>(context);
1039 if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
1040 return a && a == adaptor.getShapes().front();
1059 OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
1061 void ConstSizeOp::getAsmResultNames(
1064 llvm::raw_svector_ostream os(buffer);
1065 os << "c" << getValue();
1066 setNameFn(getResult(), os.str());
1073 OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
1079 OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
1080 return adaptor.getPred();
1087 std::optional<int64_t> DimOp::getConstantIndex() {
1088 if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1089 return constSizeOp.getValue().getLimitedValue();
1090 if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1091 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1092 return std::nullopt;
1096 Type valType = getValue().getType();
1097 auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1098 if (!valShapedType || !valShapedType.hasRank())
1100 std::optional<int64_t> index = getConstantIndex();
1101 if (!index.has_value())
1103 if (index.value() < 0 || index.value() >= valShapedType.getRank())
1105 auto extent = valShapedType.getDimSize(*index);
1106 if (ShapedType::isDynamic(extent))
1111 LogicalResult mlir::shape::DimOp::inferReturnTypes(
1112 MLIRContext *context, std::optional<Location> location,
1114 inferredReturnTypes.assign({adaptor.getIndex().getType()});
1119 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1127 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1130 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1131 if (!rhs || rhs.getValue().isZero())
1136 APInt quotient, remainder;
1137 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1138 if (quotient.isNegative() && !remainder.isZero()) {
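// shape.div rounds toward negative infinity, so a negative quotient with a
// nonzero remainder still has to be adjusted down by one before folding.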
1146 LogicalResult mlir::shape::DivOp::inferReturnTypes(
1147 MLIRContext *context, std::optional<Location> location,
1149 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1150 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1159 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1169 bool allSame = true;
1170 if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1172 for (Attribute operand : adaptor.getShapes().drop_front()) {
1175 allSame = allSame && operand == adaptor.getShapes().front();
1184 OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1194 patterns.add<SizeToIndexToSizeCanonicalization>(context);
1201 OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1202 if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
1205 for (auto attr : adaptor.getExtents())
1206 extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1221 FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1222 auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1226 return lookupSymbol<FuncOp>(attr);
1232 StringAttr nameAttr;
1247 DictionaryAttr mappingAttr;
1259 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1271 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1275 FuncOp::build(builder, state, name, type, attrs);
1278 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1283 FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1286 FuncOp func = create(location, name, type, attrs);
1287 func.setAllArgAttrs(argAttrs);
1294 state.addAttribute(FuncOp::getSymNameAttrName(state.name),
1296 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
1298 state.attributes.append(attrs.begin(), attrs.end());
1301 if (argAttrs.empty())
1303 assert(type.getNumInputs() == argAttrs.size());
1305 builder, state, argAttrs, {},
1306 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
1310 auto buildFuncType =
1313 std::string &) { return builder.getFunctionType(argTypes, results); };
1316 parser, result, false,
1317 getFunctionTypeAttrName(result.name), buildFuncType,
1318 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1323 p, *this, false, getFunctionTypeAttrName(),
1324 getArgAttrsAttrName(), getResAttrsAttrName());
1331 std::optional<int64_t> GetExtentOp::getConstantDim() {
1332 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1333 return constSizeOp.getValue().getLimitedValue();
1334 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1335 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1336 return std::nullopt;
1340 auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1343 std::optional<int64_t> dim = getConstantDim();
1344 if (!dim.has_value())
1346 if (dim.value() >= elements.getNumElements())
1348 return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1355 if (llvm::isa<ShapeType>(shape.getType())) {
1356 Value dim = ConstSizeOp::create(builder, loc, dimAttr);
1357 build(builder, result, builder.getType<SizeType>(), shape, dim);
1361 build(builder, result, builder.getIndexType(), shape, dim);
1365 LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1366 MLIRContext *context, std::optional<Location> location,
1372 bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1375 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1386 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1389 OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1391 if (adaptor.getShapes().size() < 2) {
1402 LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1403 MLIRContext *context, std::optional<Location> location,
1405 if (adaptor.getOperands().empty())
1408 auto isShapeType = [](Type arg) {
1409 if (llvm::isa<ShapeType>(arg))
1415 Type acc = types.front();
1416 for (auto t : drop_begin(types)) {
1417 Type l = acc, r = t;
1418 if (!llvm::isa<ShapeType, SizeType>(l))
1422 if (llvm::isa<SizeType>(l)) {
1423 if (llvm::isa<SizeType, IndexType>(r))
1427 } else if (llvm::isa<IndexType>(l)) {
1428 if (llvm::isa<IndexType>(r))
1432 } else if (llvm::isa<ShapeType>(l)) {
1439 auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
1440 auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
1441 if (ShapedType::isDynamic(rank1))
1443 else if (ShapedType::isDynamic(rank2))
1445 else if (rank1 != rank2)
1451 inferredReturnTypes.assign({acc});
1456 if (l.size() != 1 || r.size() != 1)
1461 Type lhs = l.front();
1462 Type rhs = r.front();
1464 if (!llvm::isa<ShapeType, SizeType>(lhs))
1465 std::swap(lhs, rhs);
1467 if (llvm::isa<SizeType>(lhs))
1468 return llvm::isa<SizeType, IndexType>(rhs);
1469 if (llvm::isa<ShapeType>(lhs))
1470 return llvm::isa<ShapeType, TensorType>(rhs);
1481 OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1482 auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1485 int64_t rank = shape.getNumElements();
1505 struct RankShapeOfCanonicalizationPattern
1509 LogicalResult matchAndRewrite(shape::RankOp op,
1511 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1514 auto rankedTensorType =
1515 llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1516 if (!rankedTensorType)
1518 int64_t rank = rankedTensorType.getRank();
1519 if (llvm::isa<IndexType>(op.getType())) {
1522 } else if (llvm::isa<shape::SizeType>(op.getType())) {
1534 patterns.add<RankShapeOfCanonicalizationPattern>(context);
1537 LogicalResult mlir::shape::RankOp::inferReturnTypes(
1538 MLIRContext *context, std::optional<Location> location,
1540 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1549 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1558 OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1566 for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
1572 LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1573 MLIRContext *context, std::optional<Location> location,
1574 NumElementsOp::Adaptor adaptor,
1576 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1583 bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1586 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1599 if (getLhs() == getRhs())
1604 LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1605 MLIRContext *context, std::optional<Location> location,
1607 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1608 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1615 if (l.size() != 1 || r.size() != 1)
1617 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1619 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1630 if (getLhs() == getRhs())
1635 LogicalResult mlir::shape::MinOp::inferReturnTypes(
1636 MLIRContext *context, std::optional<Location> location,
1638 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1639 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1646 if (l.size() != 1 || r.size() != 1)
1648 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1650 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1660 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1663 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1666 APInt folded = lhs.getValue() * rhs.getValue();
1671 LogicalResult mlir::shape::MulOp::inferReturnTypes(
1672 MLIRContext *context, std::optional<Location> location,
1674 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1675 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1684 return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);
1695 struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
1698 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1700 auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1701 if (!type || !type.hasStaticShape())
1705 ConstShapeOp::create(rewriter, loc,
1708 if (constShape.getType() != op.getResult().getType())
1709 constShape = tensor::CastOp::create(rewriter, loc,
1710 op.getResult().getType(), constShape);
1729 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1731 auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1732 if (!tensorReshapeOp)
1734 if (!isa<TensorType>(op.getType()))
1745 Value shape = tensorReshapeOp.getShape();
1747 auto opTensorTy = cast<RankedTensorType>(op.getType());
1748 auto shapeTensorTy = cast<RankedTensorType>(shape.getType());
1750 if (opTensorTy != shapeTensorTy) {
1751 if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
1753 tensor::CastOp::create(rewriter, op.getLoc(), opTensorTy, shape);
1755 shape = arith::IndexCastOp::create(rewriter, op.getLoc(), opTensorTy,
1776 LogicalResult matchAndRewrite(tensor::CastOp op,
1778 auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1779 if (!ty || ty.getRank() != 1)
1782 auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1787 auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1788 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1799 patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1800 ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1804 LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1805 MLIRContext *context, std::optional<Location> location,
1807 if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1810 auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1812 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1815 inferredReturnTypes.assign({extentTensorTy});
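// For tensor arguments the result is an extent tensor whose static size is
// the argument's rank, or a dynamically sized one if the argument is unranked.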
1821 if (l.size() != 1 || r.size() != 1)
1826 Type lhs = l.front();
1827 Type rhs = r.front();
1829 if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1830 !llvm::isa<ShapeType, ShapedType>(rhs))
1833 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1850 OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1860 patterns.add<IndexToSizeToIndexCanonicalization>(context);
1864 if (inputs.size() != 1 || outputs.size() != 1)
1866 return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1867 llvm::isa<IndexType>(outputs[0]);
1875 auto *parentOp = (*this)->getParentOp();
1876 auto results = parentOp->getResults();
1877 auto operands = getOperands();
1879 if (parentOp->getNumResults() != getNumOperands())
1880 return emitOpError() << "number of operands does not match number of "
1881 "results of its parent";
1882 for (auto e : llvm::zip(results, operands))
1884 return emitOpError() << "types mismatch between yield op and its parent";
1893 LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1895 if (!adaptor.getOperand() || !adaptor.getIndex())
1897 auto shapeVec = llvm::to_vector<6>(
1898 llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1900 auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1903 int64_t rank = shape.size();
1904 if (-rank > splitPoint || splitPoint > rank)
1907 splitPoint += shape.size();
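// Negative split points count from the back, so they are normalized by the
// rank before the constant shape is sliced into head and tail.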
1908 Builder builder(adaptor.getOperand().getContext());
1918 OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1919 if (!adaptor.getInput())
1922 auto shape = llvm::to_vector<6>(
1923 llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1930 if (inputs.size() != 1 || outputs.size() != 1)
1932 if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1933 if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1934 inputTensor.getRank() != 1)
1936 } else if (!llvm::isa<ShapeType>(inputs[0])) {
1940 TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1941 return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1959 if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
1960 elementType = tensorType.getElementType();
1965 for (Value initVal : initVals) {
1966 bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
1967 result.addTypes(initVal.getType());
1976 auto blockArgsCount = getInitVals().size() + 2;
1978 return emitOpError() << "ReduceOp body is expected to have "
1979 << blockArgsCount << " arguments";
1984 "argument 0 of ReduceOp body is expected to be of IndexType");
1991 if (!llvm::isa<SizeType>(extentTy))
1992 return emitOpError(
"argument 1 of ReduceOp body is expected to be of "
1993 "SizeType if the ReduceOp operates on a ShapeType");
1995 if (!llvm::isa<IndexType>(extentTy))
1997 "argument 1 of ReduceOp body is expected to be of IndexType if the "
1998 "ReduceOp operates on an extent tensor");
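// The second block argument is the extent being reduced; it must be a
// SizeType when reducing a !shape.shape and an IndexType for extent tensors.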
2003 return emitOpError() << "type mismatch between argument "
2005 << " of ReduceOp body and initial value "
2013 Type shapeOrExtentTensorType;
2022 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
2041 p << '(' << getShape() << ", " << getInitVals()
2049 #define GET_OP_CLASSES
2050 #include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2052 #define GET_TYPEDEF_CLASSES
2053 #include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"