25#include "llvm/ADT/SetOperations.h"
26#include "llvm/ADT/SmallVectorExtras.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/raw_ostream.h"
34#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
37#include "ShapeCanonicalization.inc"
41 return RankedTensorType::get({rank}, IndexType::get(ctx));
45 auto ranked = llvm::dyn_cast<RankedTensorType>(type);
46 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
52 auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
55 llvm::append_range(shapeValues, type.getShape());
60 llvm::append_range(shapeValues, attr.getValues<
int64_t>());
67 return llvm::any_of(operandTypes,
68 llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
75 if (!llvm::isa<SizeType>(resultTy))
77 <<
"if at least one of the operands can hold error values then "
78 "the result must be of type `size` to propagate them";
87 if (!llvm::isa<ShapeType>(resultTy))
89 <<
"if at least one of the operands can hold error values then "
90 "the result must be of type `shape` to propagate them";
95template <
typename... Ty>
97 return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
100template <
typename... Ty,
typename... ranges>
111struct ShapeInlinerInterface :
public DialectInlinerInterface {
112 using DialectInlinerInterface::DialectInlinerInterface;
117 IRMapping &)
const final {
125 IRMapping &)
const final {
131void ShapeDialect::initialize() {
134#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
137#define GET_TYPEDEF_LIST
138#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
140 addInterfaces<ShapeInlinerInterface>();
144 allowUnknownOperations();
145 declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
152 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
153 return ub::PoisonOp::create(builder, loc, type, poison);
156 return ConstShapeOp::create(builder, loc, type,
157 llvm::cast<DenseIntElementsAttr>(value));
158 if (llvm::isa<SizeType>(type))
159 return ConstSizeOp::create(builder, loc, type,
160 llvm::cast<IntegerAttr>(value));
161 if (llvm::isa<WitnessType>(type))
162 return ConstWitnessOp::create(builder, loc, type,
163 llvm::cast<BoolAttr>(value));
165 return arith::ConstantOp::materialize(builder, value, type, loc);
168LogicalResult ShapeDialect::verifyOperationAttribute(
Operation *op,
171 if (attribute.
getName() ==
"shape.lib") {
174 "shape.lib attribute may only be on op implementing SymbolTable");
176 if (
auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.
getValue())) {
179 return op->
emitError(
"shape function library ")
180 << symbolRef <<
" not found";
181 return isa<shape::FunctionLibraryOp>(symbol)
184 << symbolRef <<
" required to be shape function library";
187 if (
auto arr = llvm::dyn_cast<ArrayAttr>(attribute.
getValue())) {
191 for (
auto it : arr) {
192 if (!llvm::isa<SymbolRefAttr>(it))
194 "only SymbolRefAttr allowed in shape.lib attribute array");
196 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
200 << it <<
" does not refer to FunctionLibraryOp";
201 for (
auto mapping : shapeFnLib.getMapping()) {
202 if (!key.insert(mapping.getName()).second) {
203 return op->
emitError(
"only one op to shape mapping allowed, found "
205 << mapping.getName() <<
"`";
212 return op->
emitError(
"only SymbolRefAttr or array of SymbolRefAttrs "
213 "allowed as shape.lib attribute");
226 if (adaptor.getInputs().back())
227 return adaptor.getInputs().back();
237 result.regions.reserve(1);
254 AssumingOp::ensureTerminator(*doRegion, parser.
getBuilder(),
result.location);
263 bool yieldsResults = !getResults().empty();
265 p <<
" " << getWitness();
267 p <<
" -> (" << getResultTypes() <<
")";
278 using OpRewritePattern<AssumingOp>::OpRewritePattern;
280 LogicalResult matchAndRewrite(AssumingOp op,
281 PatternRewriter &rewriter)
const override {
282 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
283 if (!witness || !witness.getPassingAttr())
286 AssumingOp::inlineRegionIntoParent(op, rewriter);
292 using OpRewritePattern<AssumingOp>::OpRewritePattern;
294 LogicalResult matchAndRewrite(AssumingOp op,
295 PatternRewriter &rewriter)
const override {
296 Block *body = op.getBody();
297 auto yieldOp = llvm::cast<AssumingYieldOp>(body->
getTerminator());
300 SmallVector<Value, 4> newYieldOperands;
301 for (
auto [opResult, yieldOperand] :
302 llvm::zip(op.getResults(), yieldOp.getOperands())) {
303 if (!opResult.getUses().empty()) {
304 newYieldOperands.push_back(yieldOperand);
309 if (newYieldOperands.size() == yieldOp->getNumOperands())
318 auto newOp = AssumingOp::create(
319 rewriter, op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
320 newOp.getDoRegion().takeBody(op.getDoRegion());
323 SmallVector<Value, 4> replacementValues;
324 auto src = newOp.getResults().begin();
325 for (
auto it : op.getResults()) {
326 if (it.getUses().empty())
327 replacementValues.push_back(
nullptr);
329 replacementValues.push_back(*src++);
331 rewriter.
replaceOp(op, replacementValues);
339 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
343void AssumingOp::getSuccessorRegions(
360void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
363 auto *assumingBlock = op.getBody();
365 auto *blockAfterAssuming =
366 rewriter.
splitBlock(blockBeforeAssuming, initPosition);
369 auto &yieldOp = assumingBlock->
back();
371 rewriter.
replaceOp(op, yieldOp.getOperands());
376 rewriter.
mergeBlocks(assumingBlock, blockBeforeAssuming);
377 rewriter.
mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
380void AssumingOp::build(
385 result.addOperands(witness);
391 AssumingYieldOp::create(builder,
result.location, yieldValues);
394 for (
Value v : yieldValues)
395 assumingTypes.push_back(v.getType());
396 result.addTypes(assumingTypes);
403LogicalResult mlir::shape::AddOp::inferReturnTypes(
404 MLIRContext *context, std::optional<Location> location,
406 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
407 llvm::isa<SizeType>(adaptor.getRhs().getType()))
408 inferredReturnTypes.assign({SizeType::get(context)});
410 inferredReturnTypes.assign({IndexType::get(context)});
419OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
425 adaptor.getOperands(),
426 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
446 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
448 LogicalResult matchAndRewrite(AssumingAllOp op,
449 PatternRewriter &rewriter)
const override {
450 SmallVector<Value> operands;
452 for (Value operand : op.getInputs()) {
453 if (
auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
454 operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
456 operands.push_back(operand);
460 if (operands.size() == op.getNumOperands())
490struct AssumingAllOfCstrBroadcastable :
public OpRewritePattern<AssumingAllOp> {
491 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
493 LogicalResult matchAndRewrite(AssumingAllOp op,
494 PatternRewriter &rewriter)
const override {
497 for (Value operand : op.getInputs()) {
500 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
504 operands.insert(broadcastable);
508 if (operands.size() <= 1)
// Pair each CstrBroadcastableOp with a DenseSet<Value>; presumably the set of
// that constraint's shape operands (it is built on a line not visible here).
// Subsumption between constraints is then checked with set operations.
512 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
513 for (
auto cstr : operands) {
515 shapes.emplace_back(cstr, std::move(shapesSet));
// Order constraints by descending operand count. In the loop below, each
// constraint is then compared only against later entries, which have no more
// operands than it does.
// NOTE(review): the comparator takes `auto a, auto b` by value, which copies
// each std::pair (including its DenseSet<Value>) on every comparison.
// `const auto &` parameters would avoid those copies.
519 llvm::sort(shapes, [](
auto a,
auto b) {
520 return a.first.getNumOperands() >
b.first.getNumOperands();
527 SmallVector<CstrBroadcastableOp> markedForErase;
// For each surviving constraint i, drop every later constraint whose set is a
// subset of constraint i's set, and record the dropped ops in markedForErase.
// NOTE(review): std::remove_if leaves the elements in [it, end) with
// unspecified (possibly moved-from) values. They are not guaranteed to be the
// rejected elements, so reading it0->first there can record a kept constraint
// and miss a removed one. std::stable_partition (with the negated predicate)
// keeps the surviving order that the sort above relies on and leaves the
// rejected elements intact in the tail. It looks like the intended primitive;
// confirm before changing.
529 for (
unsigned i = 0; i < shapes.size(); ++i) {
530 auto isSubset = [&](
auto pair) {
531 return llvm::set_is_subset(pair.second, shapes[i].second);
535 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
536 for (
auto *it0 = it; it0 < shapes.end(); ++it0)
537 markedForErase.push_back(it0->first);
538 shapes.erase(it, shapes.end());
542 if (markedForErase.empty())
546 SmallVector<Value> uniqueConstraints;
547 for (
auto &shape : shapes)
548 uniqueConstraints.push_back(shape.first.getResult());
554 for (
auto &op : markedForErase)
562struct AssumingAllToCstrEqCanonicalization
564 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
566 LogicalResult matchAndRewrite(AssumingAllOp op,
567 PatternRewriter &rewriter)
const override {
568 SmallVector<Value, 8> shapes;
569 for (Value w : op.getInputs()) {
570 auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
573 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
574 return llvm::is_contained(shapes, s);
576 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
578 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
585template <
typename OpTy>
587 using OpRewritePattern<OpTy>::OpRewritePattern;
589 LogicalResult matchAndRewrite(OpTy op,
590 PatternRewriter &rewriter)
const override {
595 if (unique.size() < op.getNumOperands()) {
597 unique.takeVector(), op->getAttrs());
609 .add<MergeAssumingAllOps, AssumingAllOneOp,
610 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
611 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
617 for (
int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
625 getOperation()->eraseOperand(idx);
628 if (!llvm::cast<BoolAttr>(a).getValue())
635LogicalResult AssumingAllOp::verify() {
637 if (getNumOperands() == 0)
648 if (getShapes().size() == 1) {
652 return getShapes().front();
655 if (!adaptor.getShapes().front())
659 llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
662 for (
auto next : adaptor.getShapes().drop_front()) {
665 auto nextShape = llvm::to_vector<6>(
666 llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
675 std::copy(tmpShape.begin(), tmpShape.end(),
676 std::back_inserter(resultShape));
683LogicalResult BroadcastOp::verify() {
688template <
typename OpTy>
690 using OpRewritePattern<OpTy>::OpRewritePattern;
692 LogicalResult matchAndRewrite(OpTy op,
693 PatternRewriter &rewriter)
const override {
694 auto isPotentiallyNonEmptyShape = [](Value shape) {
695 if (
auto extentTensorTy =
696 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
697 if (extentTensorTy.getDimSize(0) == 0)
700 if (
auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
701 if (constShape.getShape().empty())
706 auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
707 isPotentiallyNonEmptyShape);
711 if (newOperands.empty()) {
718 if (newOperands.size() < op.getNumOperands()) {
728struct BroadcastForwardSingleOperandPattern
730 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
732 LogicalResult matchAndRewrite(BroadcastOp op,
733 PatternRewriter &rewriter)
const override {
734 if (op.getNumOperands() != 1)
740 auto loc = op.getLoc();
741 if (llvm::isa<ShapeType>(op.getType())) {
744 assert(!llvm::isa<ShapeType>(op.getType()) &&
746 "expect extent tensor cast");
748 tensor::CastOp::create(rewriter, loc, op.getType(),
replacement);
757struct BroadcastFoldConstantOperandsPattern
759 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
761 LogicalResult matchAndRewrite(BroadcastOp op,
762 PatternRewriter &rewriter)
const override {
763 SmallVector<int64_t, 8> foldedConstantShape;
764 SmallVector<Value, 8> newShapeOperands;
765 for (Value shape : op.getShapes()) {
766 if (
auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
767 SmallVector<int64_t, 8> newFoldedConstantShape;
770 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
771 newFoldedConstantShape)) {
772 foldedConstantShape = newFoldedConstantShape;
776 newShapeOperands.push_back(shape);
780 if (op.getNumOperands() - newShapeOperands.size() < 2)
783 auto foldedConstantOperandsTy = RankedTensorType::get(
784 {
static_cast<int64_t
>(foldedConstantShape.size())},
786 newShapeOperands.push_back(
787 ConstShapeOp::create(rewriter, op.getLoc(), foldedConstantOperandsTy,
795template <
typename OpTy>
796struct CanonicalizeCastExtentTensorOperandsPattern
798 using OpRewritePattern<OpTy>::OpRewritePattern;
800 LogicalResult matchAndRewrite(OpTy op,
801 PatternRewriter &rewriter)
const override {
803 bool anyChange =
false;
804 auto canonicalizeOperand = [&](Value operand) -> Value {
805 if (
auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
807 bool isInformationLoosingCast =
808 llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
809 if (isInformationLoosingCast) {
811 return castOp.getSource();
817 llvm::map_to_vector<8>(op.getOperands(), canonicalizeOperand);
827struct BroadcastConcretizeResultTypePattern
829 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
831 LogicalResult matchAndRewrite(BroadcastOp op,
832 PatternRewriter &rewriter)
const override {
834 auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
835 if (!resultTy || !resultTy.isDynamicDim(0))
840 for (Value shape : op.getShapes()) {
841 if (
auto extentTensorTy =
842 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
845 if (extentTensorTy.isDynamicDim(0))
847 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
851 auto newOp = BroadcastOp::create(rewriter, op.getLoc(),
862 patterns.add<BroadcastConcretizeResultTypePattern,
863 BroadcastFoldConstantOperandsPattern,
864 BroadcastForwardSingleOperandPattern,
865 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
866 RemoveDuplicateOperandsPattern<BroadcastOp>,
867 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
875 if (!adaptor.getLhs() || !adaptor.getRhs())
877 auto lhsShape = llvm::to_vector<6>(
878 llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<
int64_t>());
879 auto rhsShape = llvm::to_vector<6>(
880 llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<
int64_t>());
882 resultShape.append(lhsShape.begin(), lhsShape.end());
883 resultShape.append(rhsShape.begin(), rhsShape.end());
896 interleaveComma(
getShape().getValues<int64_t>(), p);
911 auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
916 IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
919 ints.push_back(attr.getInt());
926 result.types.push_back(resultTy);
// Folding a ConstShapeOp always succeeds: the fold result is the op's own
// shape attribute, returned unconditionally (the adaptor is unused).
930OpFoldResult ConstShapeOp::fold(FoldAdaptor) {
return getShapeAttr(); }
934 patterns.add<TensorCastConstShape>(context);
937LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
938 MLIRContext *context, std::optional<Location> location,
941 const Properties prop = adaptor.getProperties();
942 inferredReturnTypes.assign({RankedTensorType::get(
943 {
static_cast<int64_t>(prop.shape.size())},
b.getIndexType())});
947bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(
TypeRange l,
949 if (l.size() != 1 || r.size() != 1)
955 if (llvm::isa<ShapeType>(
lhs) || llvm::isa<ShapeType>(
rhs))
965void CstrBroadcastableOp::getCanonicalizationPatterns(
970 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
971 CstrBroadcastableEqOps,
972 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
973 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
979 bool nonScalarSeen =
false;
981 if (!a || llvm::cast<DenseIntElementsAttr>(a).
getNumElements() != 0) {
984 nonScalarSeen =
true;
990OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
997 for (
const auto &operand : adaptor.getShapes()) {
1000 extents.push_back(llvm::to_vector<6>(
1001 llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
1011 for (
auto shapeValue : getShapes()) {
1012 extents.emplace_back();
1025LogicalResult CstrBroadcastableOp::verify() {
1027 if (getNumOperands() < 2)
1028 return emitOpError(
"required at least 2 input shapes");
1039 patterns.add<CstrEqEqOps>(context);
1043 if (llvm::all_of(adaptor.getShapes(), [&](
Attribute a) {
1044 return a && a == adaptor.getShapes().front();
// Folding a ConstSizeOp always succeeds: the fold result is the op's own value
// attribute, returned unconditionally (the adaptor is unused).
1063OpFoldResult ConstSizeOp::fold(FoldAdaptor) {
return getValueAttr(); }
1065void ConstSizeOp::getAsmResultNames(
1068 llvm::raw_svector_ostream os(buffer);
1069 os <<
"c" << getValue();
1070 setNameFn(getResult(), os.str());
// Folding a ConstWitnessOp always succeeds: the fold result is the op's own
// `passing` attribute, returned unconditionally (the adaptor is unused).
1077OpFoldResult ConstWitnessOp::fold(FoldAdaptor) {
return getPassingAttr(); }
1084 return adaptor.getPred();
1091std::optional<int64_t> DimOp::getConstantIndex() {
1092 if (
auto constSizeOp =
getIndex().getDefiningOp<ConstSizeOp>())
1093 return constSizeOp.getValue().getLimitedValue();
1094 if (
auto constantOp =
getIndex().getDefiningOp<arith::ConstantOp>())
1095 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1096 return std::nullopt;
1100 Type valType = getValue().getType();
1101 auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1102 if (!valShapedType || !valShapedType.hasRank())
1104 std::optional<int64_t>
index = getConstantIndex();
1105 if (!
index.has_value())
1107 if (
index.value() < 0 ||
index.value() >= valShapedType.getRank())
1109 auto extent = valShapedType.getDimSize(*
index);
1110 if (ShapedType::isDynamic(extent))
1112 return IntegerAttr::get(IndexType::get(
getContext()), extent);
1115LogicalResult mlir::shape::DimOp::inferReturnTypes(
1116 MLIRContext *context, std::optional<Location> location,
1118 inferredReturnTypes.assign({adaptor.getIndex().getType()});
1131 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1134 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1135 if (!
rhs ||
rhs.getValue().isZero())
1140 APInt quotient, remainder;
1141 APInt::sdivrem(
lhs.getValue(),
rhs.getValue(), quotient, remainder);
1142 if (quotient.isNegative() && !remainder.isZero()) {
1147 return IntegerAttr::get(indexTy, quotient);
1150LogicalResult mlir::shape::DivOp::inferReturnTypes(
1151 MLIRContext *context, std::optional<Location> location,
1153 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1154 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1155 inferredReturnTypes.assign({SizeType::get(context)});
1157 inferredReturnTypes.assign({IndexType::get(context)});
1173 bool allSame =
true;
1174 if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1176 for (
Attribute operand : adaptor.getShapes().drop_front()) {
1179 allSame = allSame && operand == adaptor.getShapes().front();
1198 patterns.add<SizeToIndexToSizeCanonicalization>(context);
1207 for (
Attribute attr : adaptor.getExtents()) {
1208 auto intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr);
1211 extents.push_back(intAttr.getInt());
1227FuncOp FunctionLibraryOp::getShapeFunction(
Operation *op) {
1228 auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1232 return lookupSymbol<FuncOp>(attr);
1235ParseResult FunctionLibraryOp::parse(
OpAsmParser &parser,
1238 StringAttr nameAttr;
1246 auto *bodyRegion =
result.addRegion();
1253 DictionaryAttr mappingAttr;
1265 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(),
"mapping"});
1277FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
1281 FuncOp::build(builder, state, name, type, attrs);
1284FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
1289FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
1292 FuncOp
func = create(location, name, type, attrs);
1293 func.setAllArgAttrs(argAttrs);
1303 TypeAttr::get(type));
1307 if (argAttrs.empty())
1309 assert(type.getNumInputs() == argAttrs.size());
1311 builder, state, argAttrs, {},
1312 getArgAttrsAttrName(state.
name), getResAttrsAttrName(state.
name));
1316 auto buildFuncType =
1319 std::string &) {
return builder.
getFunctionType(argTypes, results); };
1323 getFunctionTypeAttrName(
result.name), buildFuncType,
1324 getArgAttrsAttrName(
result.name), getResAttrsAttrName(
result.name));
1329 p, *
this,
false, getFunctionTypeAttrName(),
1330 getArgAttrsAttrName(), getResAttrsAttrName());
1337std::optional<int64_t> GetExtentOp::getConstantDim() {
1338 if (
auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1339 return constSizeOp.getValue().getLimitedValue();
1340 if (
auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1341 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1342 return std::nullopt;
// Fold GetExtentOp when the shape operand folds to a DenseIntElementsAttr and
// the dimension is a known constant (see getConstantDim). The result is the
// extent attribute stored at that index.
// NOTE(review): only the upper bound of `dim` is checked on the visible lines.
// getConstantDim() can return a negative value (it reads an arith::ConstantOp
// via getInt()). The (uint64_t) cast in the final index would then turn that
// value into a huge out-of-range index. DimOp's folder rejects `index < 0`; a
// matching `dim.value() < 0` guard looks warranted here. Confirm that no such
// check sits on the lines not shown.
1346 auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1349 std::optional<int64_t> dim = getConstantDim();
1350 if (!dim.has_value())
1352 if (dim.value() >= elements.getNumElements())
1354 return elements.getValues<
Attribute>()[(uint64_t)dim.value()];
1359 auto loc =
result.location;
1361 if (llvm::isa<ShapeType>(
shape.getType())) {
1362 Value dim = ConstSizeOp::create(builder, loc, dimAttr);
1371LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1372 MLIRContext *context, std::optional<Location> location,
1374 inferredReturnTypes.assign({IndexType::get(context)});
1378bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(
TypeRange l,
1392 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1395OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1397 if (adaptor.getShapes().size() < 2) {
1408LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1409 MLIRContext *context, std::optional<Location> location,
1411 if (adaptor.getOperands().empty())
1414 auto isShapeType = [](
Type arg) {
1415 if (llvm::isa<ShapeType>(arg))
1422 for (
auto t : drop_begin(types)) {
1424 if (!llvm::isa<ShapeType, SizeType>(l))
1428 if (llvm::isa<SizeType>(l)) {
1429 if (llvm::isa<SizeType, IndexType>(r))
1433 }
else if (llvm::isa<IndexType>(l)) {
1434 if (llvm::isa<IndexType>(r))
1438 }
else if (llvm::isa<ShapeType>(l)) {
1445 auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
1446 auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
1447 if (ShapedType::isDynamic(rank1))
1449 else if (ShapedType::isDynamic(rank2))
1451 else if (rank1 != rank2)
1457 inferredReturnTypes.assign({
acc});
1462 if (l.size() != 1 || r.size() != 1)
1470 if (!llvm::isa<ShapeType, SizeType>(
lhs))
1473 if (llvm::isa<SizeType>(
lhs))
1474 return llvm::isa<SizeType, IndexType>(
rhs);
1475 if (llvm::isa<ShapeType>(
lhs))
1476 return llvm::isa<ShapeType, TensorType>(
rhs);
1488 auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1511struct RankShapeOfCanonicalizationPattern
1513 using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1515 LogicalResult matchAndRewrite(shape::RankOp op,
1516 PatternRewriter &rewriter)
const override {
1517 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1520 auto rankedTensorType =
1521 llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1522 if (!rankedTensorType)
1524 int64_t rank = rankedTensorType.getRank();
1525 if (llvm::isa<IndexType>(op.getType())) {
1528 }
else if (llvm::isa<shape::SizeType>(op.getType())) {
1540 patterns.add<RankShapeOfCanonicalizationPattern>(context);
1543LogicalResult mlir::shape::RankOp::inferReturnTypes(
1544 MLIRContext *context, std::optional<Location> location,
1546 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1547 inferredReturnTypes.assign({SizeType::get(context)});
1549 inferredReturnTypes.assign({IndexType::get(context)});
1572 for (
auto value : llvm::cast<DenseIntElementsAttr>(
shape))
1578LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1579 MLIRContext *context, std::optional<Location> location,
1580 NumElementsOp::Adaptor adaptor,
1582 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1583 inferredReturnTypes.assign({SizeType::get(context)});
1585 inferredReturnTypes.assign({IndexType::get(context)});
1589bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(
TypeRange l,
1595LogicalResult shape::NumElementsOp::verify() {
1605 if (getLhs() == getRhs())
1610LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1611 MLIRContext *context, std::optional<Location> location,
1613 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1614 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1616 inferredReturnTypes.assign({SizeType::get(context)});
1621 if (l.size() != 1 || r.size() != 1)
1623 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1625 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1636 if (getLhs() == getRhs())
1641LogicalResult mlir::shape::MinOp::inferReturnTypes(
1642 MLIRContext *context, std::optional<Location> location,
1644 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1645 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1647 inferredReturnTypes.assign({SizeType::get(context)});
1652 if (l.size() != 1 || r.size() != 1)
1654 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1656 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1666 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1669 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1672 APInt folded =
lhs.getValue() *
rhs.getValue();
1674 return IntegerAttr::get(indexTy, folded);
1677LogicalResult mlir::shape::MulOp::inferReturnTypes(
1678 MLIRContext *context, std::optional<Location> location,
1680 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1681 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1682 inferredReturnTypes.assign({SizeType::get(context)});
1684 inferredReturnTypes.assign({IndexType::get(context)});
1701struct ShapeOfOpToConstShapeOp :
public OpRewritePattern<shape::ShapeOfOp> {
1702 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1704 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1705 PatternRewriter &rewriter)
const override {
1706 auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1707 if (!type || !type.hasStaticShape())
1710 Type resultType = op.getResult().getType();
1711 Location loc = op.getLoc();
1713 isa<ShapeType>(resultType)
1715 : RankedTensorType::get({type.getRank()}, rewriter.
getIndexType());
1717 ConstShapeOp::create(rewriter, loc, constResType,
1720 if (constShape.
getType() != resultType)
1722 tensor::CastOp::create(rewriter, loc, resultType, constShape);
1739 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1741 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1742 PatternRewriter &rewriter)
const override {
1743 auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1744 if (!tensorReshapeOp)
1746 if (!isa<TensorType>(op.getType()))
1757 Value shape = tensorReshapeOp.getShape();
1759 auto opTensorTy = cast<RankedTensorType>(op.getType());
1760 auto shapeTensorTy = cast<RankedTensorType>(shape.
getType());
1762 if (opTensorTy != shapeTensorTy) {
1763 if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
1765 tensor::CastOp::create(rewriter, op.getLoc(), opTensorTy, shape);
1767 shape = arith::IndexCastOp::create(rewriter, op.getLoc(), opTensorTy,
1786 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1788 LogicalResult matchAndRewrite(tensor::CastOp op,
1789 PatternRewriter &rewriter)
const override {
1790 auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1791 if (!ty || ty.getRank() != 1)
1794 auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1799 auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1800 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1811 patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1812 ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1816LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1817 MLIRContext *context, std::optional<Location> location,
1819 if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1820 inferredReturnTypes.assign({ShapeType::get(context)});
1822 auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1824 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1825 Type indexTy = IndexType::get(context);
1826 Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1827 inferredReturnTypes.assign({extentTensorTy});
1833 if (l.size() != 1 || r.size() != 1)
1841 if (!llvm::isa<ShapeType, ShapedType>(
lhs) ||
1842 !llvm::isa<ShapeType, ShapedType>(
rhs))
1845 if (llvm::isa<ShapeType>(
lhs) || llvm::isa<ShapeType>(
rhs))
1854LogicalResult shape::ShapeOfOp::verify() {
1872 patterns.add<IndexToSizeToIndexCanonicalization>(context);
1876 if (inputs.size() != 1 || outputs.size() != 1)
1878 return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1879 llvm::isa<IndexType>(outputs[0]);
1886LogicalResult shape::YieldOp::verify() {
1887 auto *parentOp = (*this)->getParentOp();
1888 auto results = parentOp->getResults();
1889 auto operands = getOperands();
1891 if (parentOp->getNumResults() != getNumOperands())
1892 return emitOpError() <<
"number of operands does not match number of "
1893 "results of its parent";
1894 for (
auto e : llvm::zip(results, operands))
1896 return emitOpError() <<
"types mismatch between yield op and its parent";
1905LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1907 if (!adaptor.getOperand() || !adaptor.getIndex())
1909 auto shapeVec = llvm::to_vector<6>(
1910 llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<
int64_t>());
1912 auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1916 if (-rank > splitPoint || splitPoint > rank)
1919 splitPoint +=
shape.size();
1920 Builder builder(adaptor.getOperand().getContext());
1930OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1931 if (!adaptor.getInput())
1934 auto shape = llvm::to_vector<6>(
1935 llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<
int64_t>());
1936 auto type = RankedTensorType::get({
static_cast<int64_t>(
shape.size())},
1942 if (inputs.size() != 1 || outputs.size() != 1)
1944 if (
auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1945 if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1946 inputTensor.getRank() != 1)
1948 }
else if (!llvm::isa<ShapeType>(inputs[0])) {
1952 TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1953 return outputTensor && llvm::isa<IndexType>(outputTensor.
getElementType());
1964 result.addOperands(initVals);
1971 if (
auto tensorType = llvm::dyn_cast<TensorType>(
shape.getType()))
1972 elementType = tensorType.getElementType();
1974 elementType = SizeType::get(builder.
getContext());
1977 for (
Value initVal : initVals) {
1978 bodyBlock->
addArgument(initVal.getType(), initVal.getLoc());
1979 result.addTypes(initVal.getType());
1983LogicalResult ReduceOp::verify() {
1988 auto blockArgsCount = getInitVals().size() + 2;
1990 return emitOpError() <<
"ReduceOp body is expected to have "
1991 << blockArgsCount <<
" arguments";
1996 "argument 0 of ReduceOp body is expected to be of IndexType");
2003 if (!llvm::isa<SizeType>(extentTy))
2004 return emitOpError(
"argument 1 of ReduceOp body is expected to be of "
2005 "SizeType if the ReduceOp operates on a ShapeType");
2007 if (!llvm::isa<IndexType>(extentTy))
2009 "argument 1 of ReduceOp body is expected to be of IndexType if the "
2010 "ReduceOp operates on an extent tensor");
2013 for (
const auto &type : llvm::enumerate(getInitVals()))
2015 return emitOpError() <<
"type mismatch between argument "
2017 <<
" of ReduceOp body and initial value "
2025 Type shapeOrExtentTensorType;
2034 if (parser.
resolveOperand(operands.front(), shapeOrExtentTensorType,
2053 p <<
'(' <<
getShape() <<
", " << getInitVals()
2061#define GET_OP_CLASSES
2062#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2064#define GET_TYPEDEF_CLASSES
2065#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool isErrorPropagationPossible(TypeRange operandTypes)
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
static LogicalResult verifySizeOrIndexOp(Operation *op)
static int64_t product(ArrayRef< int64_t > vals)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
MLIRContext * getContext() const
NamedAttribute getNamedAttr(StringRef name, Attribute val)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
A trait used to provide symbol table functionalities to a region operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
iterator_range< dialect_attr_iterator > dialect_attr_range
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
ValueTypeRange< ValueRange > type_range
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
bool isExtentTensorType(Type)
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.