25#include "llvm/ADT/SetOperations.h"
26#include "llvm/ADT/SmallVectorExtras.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/raw_ostream.h"
34#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
37#include "ShapeCanonicalization.inc"
41 return RankedTensorType::get({rank}, IndexType::get(ctx));
45 auto ranked = llvm::dyn_cast<RankedTensorType>(type);
46 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
52 auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
55 llvm::append_range(shapeValues, type.getShape());
60 llvm::append_range(shapeValues, attr.getValues<
int64_t>());
67 return llvm::any_of(operandTypes,
68 llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
75 if (!llvm::isa<SizeType>(resultTy))
77 <<
"if at least one of the operands can hold error values then "
78 "the result must be of type `size` to propagate them";
87 if (!llvm::isa<ShapeType>(resultTy))
89 <<
"if at least one of the operands can hold error values then "
90 "the result must be of type `shape` to propagate them";
95template <
typename... Ty>
97 return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
100template <
typename... Ty,
typename... ranges>
111struct ShapeInlinerInterface :
public DialectInlinerInterface {
112 using DialectInlinerInterface::DialectInlinerInterface;
117 IRMapping &)
const final {
125 IRMapping &)
const final {
131void ShapeDialect::initialize() {
134#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
137#define GET_TYPEDEF_LIST
138#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
140 addInterfaces<ShapeInlinerInterface>();
144 allowUnknownOperations();
145 declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
152 if (
auto poison = dyn_cast<ub::PoisonAttr>(value))
153 return ub::PoisonOp::create(builder, loc, type, poison);
156 return ConstShapeOp::create(builder, loc, type,
157 llvm::cast<DenseIntElementsAttr>(value));
158 if (llvm::isa<SizeType>(type))
159 return ConstSizeOp::create(builder, loc, type,
160 llvm::cast<IntegerAttr>(value));
161 if (llvm::isa<WitnessType>(type))
162 return ConstWitnessOp::create(builder, loc, type,
163 llvm::cast<BoolAttr>(value));
165 return arith::ConstantOp::materialize(builder, value, type, loc);
168LogicalResult ShapeDialect::verifyOperationAttribute(
Operation *op,
171 if (attribute.
getName() ==
"shape.lib") {
174 "shape.lib attribute may only be on op implementing SymbolTable");
176 if (
auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.
getValue())) {
179 return op->
emitError(
"shape function library ")
180 << symbolRef <<
" not found";
181 return isa<shape::FunctionLibraryOp>(symbol)
184 << symbolRef <<
" required to be shape function library";
187 if (
auto arr = llvm::dyn_cast<ArrayAttr>(attribute.
getValue())) {
191 for (
auto it : arr) {
192 if (!llvm::isa<SymbolRefAttr>(it))
194 "only SymbolRefAttr allowed in shape.lib attribute array");
196 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
200 << it <<
" does not refer to FunctionLibraryOp";
201 for (
auto mapping : shapeFnLib.getMapping()) {
202 if (!key.insert(mapping.getName()).second) {
203 return op->
emitError(
"only one op to shape mapping allowed, found "
205 << mapping.getName() <<
"`";
212 return op->
emitError(
"only SymbolRefAttr or array of SymbolRefAttrs "
213 "allowed as shape.lib attribute");
226 if (adaptor.getInputs().back())
227 return adaptor.getInputs().back();
237 result.regions.reserve(1);
254 AssumingOp::ensureTerminator(*doRegion, parser.
getBuilder(),
result.location);
263 bool yieldsResults = !getResults().empty();
265 p <<
" " << getWitness();
267 p <<
" -> (" << getResultTypes() <<
")";
278 using OpRewritePattern<AssumingOp>::OpRewritePattern;
280 LogicalResult matchAndRewrite(AssumingOp op,
281 PatternRewriter &rewriter)
const override {
282 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
283 if (!witness || !witness.getPassingAttr())
286 AssumingOp::inlineRegionIntoParent(op, rewriter);
292 using OpRewritePattern<AssumingOp>::OpRewritePattern;
294 LogicalResult matchAndRewrite(AssumingOp op,
295 PatternRewriter &rewriter)
const override {
296 Block *body = op.getBody();
297 auto yieldOp = llvm::cast<AssumingYieldOp>(body->
getTerminator());
300 SmallVector<Value, 4> newYieldOperands;
301 for (
auto [opResult, yieldOperand] :
302 llvm::zip(op.getResults(), yieldOp.getOperands())) {
303 if (!opResult.getUses().empty()) {
304 newYieldOperands.push_back(yieldOperand);
309 if (newYieldOperands.size() == yieldOp->getNumOperands())
318 auto newOp = AssumingOp::create(
319 rewriter, op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
320 newOp.getDoRegion().takeBody(op.getDoRegion());
323 SmallVector<Value, 4> replacementValues;
324 auto src = newOp.getResults().begin();
325 for (
auto it : op.getResults()) {
326 if (it.getUses().empty())
327 replacementValues.push_back(
nullptr);
329 replacementValues.push_back(*src++);
331 rewriter.
replaceOp(op, replacementValues);
339 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
343void AssumingOp::getSuccessorRegions(
360void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
363 auto *assumingBlock = op.getBody();
365 auto *blockAfterAssuming =
366 rewriter.
splitBlock(blockBeforeAssuming, initPosition);
369 auto &yieldOp = assumingBlock->
back();
371 rewriter.
replaceOp(op, yieldOp.getOperands());
376 rewriter.
mergeBlocks(assumingBlock, blockBeforeAssuming);
377 rewriter.
mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
380void AssumingOp::build(
385 result.addOperands(witness);
391 AssumingYieldOp::create(builder,
result.location, yieldValues);
394 for (
Value v : yieldValues)
395 assumingTypes.push_back(v.getType());
396 result.addTypes(assumingTypes);
403LogicalResult mlir::shape::AddOp::inferReturnTypes(
404 MLIRContext *context, std::optional<Location> location,
406 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
407 llvm::isa<SizeType>(adaptor.getRhs().getType()))
408 inferredReturnTypes.assign({SizeType::get(context)});
410 inferredReturnTypes.assign({IndexType::get(context)});
419OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
425 adaptor.getOperands(),
426 [](APInt a,
const APInt &
b) { return std::move(a) + b; });
446 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
448 LogicalResult matchAndRewrite(AssumingAllOp op,
449 PatternRewriter &rewriter)
const override {
450 SmallVector<Value> operands;
452 for (Value operand : op.getInputs()) {
453 if (
auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
454 operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
456 operands.push_back(operand);
460 if (operands.size() == op.getNumOperands())
490struct AssumingAllOfCstrBroadcastable :
public OpRewritePattern<AssumingAllOp> {
491 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
493 LogicalResult matchAndRewrite(AssumingAllOp op,
494 PatternRewriter &rewriter)
const override {
497 for (Value operand : op.getInputs()) {
500 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
504 operands.insert(broadcastable);
508 if (operands.size() <= 1)
512 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
513 for (
auto cstr : operands) {
515 shapes.emplace_back(cstr, std::move(shapesSet));
519 llvm::sort(shapes, [](
auto a,
auto b) {
520 return a.first.getNumOperands() >
b.first.getNumOperands();
527 SmallVector<CstrBroadcastableOp> markedForErase;
529 for (
unsigned i = 0; i < shapes.size(); ++i) {
530 auto isSubset = [&](
auto pair) {
531 return llvm::set_is_subset(pair.second, shapes[i].second);
535 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
536 for (
auto *it0 = it; it0 < shapes.end(); ++it0)
537 markedForErase.push_back(it0->first);
538 shapes.erase(it, shapes.end());
542 if (markedForErase.empty())
546 SmallVector<Value> uniqueConstraints;
547 for (
auto &shape : shapes)
548 uniqueConstraints.push_back(shape.first.getResult());
554 for (
auto &op : markedForErase)
562struct AssumingAllToCstrEqCanonicalization
564 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
566 LogicalResult matchAndRewrite(AssumingAllOp op,
567 PatternRewriter &rewriter)
const override {
568 SmallVector<Value, 8> shapes;
569 for (Value w : op.getInputs()) {
570 auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
573 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
574 return llvm::is_contained(shapes, s);
576 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
578 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
585template <
typename OpTy>
587 using OpRewritePattern<OpTy>::OpRewritePattern;
589 LogicalResult matchAndRewrite(OpTy op,
590 PatternRewriter &rewriter)
const override {
595 if (unique.size() < op.getNumOperands()) {
597 unique.takeVector(), op->getAttrs());
609 .add<MergeAssumingAllOps, AssumingAllOneOp,
610 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
611 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
617 for (
int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
625 getOperation()->eraseOperand(idx);
628 if (!llvm::cast<BoolAttr>(a).getValue())
635LogicalResult AssumingAllOp::verify() {
637 if (getNumOperands() == 0)
648 if (getShapes().size() == 1) {
652 return getShapes().front();
655 if (!adaptor.getShapes().front())
659 llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
662 for (
auto next : adaptor.getShapes().drop_front()) {
665 auto nextShape = llvm::to_vector<6>(
666 llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
675 std::copy(tmpShape.begin(), tmpShape.end(),
676 std::back_inserter(resultShape));
683LogicalResult BroadcastOp::verify() {
688template <
typename OpTy>
690 using OpRewritePattern<OpTy>::OpRewritePattern;
692 LogicalResult matchAndRewrite(OpTy op,
693 PatternRewriter &rewriter)
const override {
694 auto isPotentiallyNonEmptyShape = [](Value shape) {
695 if (
auto extentTensorTy =
696 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
697 if (extentTensorTy.getDimSize(0) == 0)
700 if (
auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
701 if (constShape.getShape().empty())
706 auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
707 isPotentiallyNonEmptyShape);
711 if (newOperands.empty()) {
718 if (newOperands.size() < op.getNumOperands()) {
728struct BroadcastForwardSingleOperandPattern
730 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
732 LogicalResult matchAndRewrite(BroadcastOp op,
733 PatternRewriter &rewriter)
const override {
734 if (op.getNumOperands() != 1)
740 auto loc = op.getLoc();
741 if (llvm::isa<ShapeType>(op.getType())) {
744 assert(!llvm::isa<ShapeType>(op.getType()) &&
746 "expect extent tensor cast");
748 tensor::CastOp::create(rewriter, loc, op.getType(),
replacement);
757struct BroadcastFoldConstantOperandsPattern
759 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
761 LogicalResult matchAndRewrite(BroadcastOp op,
762 PatternRewriter &rewriter)
const override {
763 SmallVector<int64_t, 8> foldedConstantShape;
764 SmallVector<Value, 8> newShapeOperands;
765 for (Value shape : op.getShapes()) {
766 if (
auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
767 SmallVector<int64_t, 8> newFoldedConstantShape;
770 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
771 newFoldedConstantShape)) {
772 foldedConstantShape = newFoldedConstantShape;
776 newShapeOperands.push_back(shape);
780 if (op.getNumOperands() - newShapeOperands.size() < 2)
783 auto foldedConstantOperandsTy = RankedTensorType::get(
784 {
static_cast<int64_t
>(foldedConstantShape.size())},
786 newShapeOperands.push_back(
787 ConstShapeOp::create(rewriter, op.getLoc(), foldedConstantOperandsTy,
795template <
typename OpTy>
796struct CanonicalizeCastExtentTensorOperandsPattern
798 using OpRewritePattern<OpTy>::OpRewritePattern;
800 LogicalResult matchAndRewrite(OpTy op,
801 PatternRewriter &rewriter)
const override {
803 bool anyChange =
false;
804 auto canonicalizeOperand = [&](Value operand) -> Value {
805 if (
auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
807 bool isInformationLoosingCast =
808 llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
809 if (isInformationLoosingCast) {
811 return castOp.getSource();
817 llvm::map_to_vector<8>(op.getOperands(), canonicalizeOperand);
827struct BroadcastConcretizeResultTypePattern
829 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
831 LogicalResult matchAndRewrite(BroadcastOp op,
832 PatternRewriter &rewriter)
const override {
834 auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
835 if (!resultTy || !resultTy.isDynamicDim(0))
840 for (Value shape : op.getShapes()) {
841 if (
auto extentTensorTy =
842 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
845 if (extentTensorTy.isDynamicDim(0))
847 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
851 auto newOp = BroadcastOp::create(rewriter, op.getLoc(),
862 patterns.add<BroadcastConcretizeResultTypePattern,
863 BroadcastFoldConstantOperandsPattern,
864 BroadcastForwardSingleOperandPattern,
865 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
866 RemoveDuplicateOperandsPattern<BroadcastOp>,
867 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
875 if (!adaptor.getLhs() || !adaptor.getRhs())
877 auto lhsShape = llvm::to_vector<6>(
878 llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<
int64_t>());
879 auto rhsShape = llvm::to_vector<6>(
880 llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<
int64_t>());
882 resultShape.append(lhsShape.begin(), lhsShape.end());
883 resultShape.append(rhsShape.begin(), rhsShape.end());
896 interleaveComma(
getShape().getValues<int64_t>(), p);
911 auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
916 IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
919 ints.push_back(attr.getInt());
926 result.types.push_back(resultTy);
/// Folding a constant shape op simply yields its stored shape attribute;
/// a constant op always folds to its attribute value.
930OpFoldResult ConstShapeOp::fold(FoldAdaptor) {
return getShapeAttr(); }
934 patterns.add<TensorCastConstShape>(context);
937LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
938 MLIRContext *context, std::optional<Location> location,
941 const Properties prop = adaptor.getProperties();
942 inferredReturnTypes.assign({RankedTensorType::get(
943 {
static_cast<int64_t>(prop.shape.size())},
b.getIndexType())});
947bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(
TypeRange l,
949 if (l.size() != 1 || r.size() != 1)
955 if (llvm::isa<ShapeType>(
lhs) || llvm::isa<ShapeType>(
rhs))
965void CstrBroadcastableOp::getCanonicalizationPatterns(
970 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
971 CstrBroadcastableEqOps,
972 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
973 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
979 bool nonScalarSeen =
false;
981 if (!a || llvm::cast<DenseIntElementsAttr>(a).
getNumElements() != 0) {
984 nonScalarSeen =
true;
990OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
997 for (
const auto &operand : adaptor.getShapes()) {
1000 extents.push_back(llvm::to_vector<6>(
1001 llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
1011 for (
auto shapeValue : getShapes()) {
1012 extents.emplace_back();
1025LogicalResult CstrBroadcastableOp::verify() {
1027 if (getNumOperands() < 2)
1028 return emitOpError(
"required at least 2 input shapes");
1039 patterns.add<CstrEqEqOps>(context);
1043 if (llvm::all_of(adaptor.getShapes(), [&](
Attribute a) {
1044 return a && a == adaptor.getShapes().front();
/// Folding a constant size op yields its stored integer value attribute;
/// a constant op always folds to its attribute value.
1063OpFoldResult ConstSizeOp::fold(FoldAdaptor) {
return getValueAttr(); }
1065void ConstSizeOp::getAsmResultNames(
1068 llvm::raw_svector_ostream os(buffer);
1069 os <<
"c" << getValue();
1070 setNameFn(getResult(), os.str());
/// Folding a constant witness op yields its stored boolean "passing"
/// attribute; a constant op always folds to its attribute value.
1077OpFoldResult ConstWitnessOp::fold(FoldAdaptor) {
return getPassingAttr(); }
1084 return adaptor.getPred();
1091std::optional<int64_t> DimOp::getConstantIndex() {
1092 if (
auto constSizeOp =
getIndex().getDefiningOp<ConstSizeOp>())
1093 return constSizeOp.getValue().getLimitedValue();
1094 if (
auto constantOp =
getIndex().getDefiningOp<arith::ConstantOp>())
1095 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1096 return std::nullopt;
1100 Type valType = getValue().getType();
1101 auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1102 if (!valShapedType || !valShapedType.hasRank())
1104 std::optional<int64_t>
index = getConstantIndex();
1105 if (!
index.has_value())
1107 if (
index.value() < 0 ||
index.value() >= valShapedType.getRank())
1109 auto extent = valShapedType.getDimSize(*
index);
1110 if (ShapedType::isDynamic(extent))
1112 return IntegerAttr::get(IndexType::get(
getContext()), extent);
1115LogicalResult mlir::shape::DimOp::inferReturnTypes(
1116 MLIRContext *context, std::optional<Location> location,
1118 inferredReturnTypes.assign({adaptor.getIndex().getType()});
1131 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1134 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1135 if (!
rhs ||
rhs.getValue().isZero())
1140 APInt quotient, remainder;
1141 APInt::sdivrem(
lhs.getValue(),
rhs.getValue(), quotient, remainder);
1142 if (quotient.isNegative() && !remainder.isZero()) {
1147 return IntegerAttr::get(indexTy, quotient);
1150LogicalResult mlir::shape::DivOp::inferReturnTypes(
1151 MLIRContext *context, std::optional<Location> location,
1153 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1154 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1155 inferredReturnTypes.assign({SizeType::get(context)});
1157 inferredReturnTypes.assign({IndexType::get(context)});
1173 bool allSame =
true;
1174 if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1176 for (
Attribute operand : adaptor.getShapes().drop_front()) {
1179 allSame = allSame && operand == adaptor.getShapes().front();
1198 patterns.add<SizeToIndexToSizeCanonicalization>(context);
1206 if (llvm::any_of(adaptor.getExtents(), [](
Attribute a) { return !a; }))
1209 for (
auto attr : adaptor.getExtents())
1210 extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1225FuncOp FunctionLibraryOp::getShapeFunction(
Operation *op) {
1226 auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1230 return lookupSymbol<FuncOp>(attr);
1233ParseResult FunctionLibraryOp::parse(
OpAsmParser &parser,
1236 StringAttr nameAttr;
1244 auto *bodyRegion =
result.addRegion();
1251 DictionaryAttr mappingAttr;
1263 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(),
"mapping"});
1275FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
1279 FuncOp::build(builder, state, name, type, attrs);
1282FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
1287FuncOp FuncOp::create(
Location location, StringRef name, FunctionType type,
1290 FuncOp
func = create(location, name, type, attrs);
1291 func.setAllArgAttrs(argAttrs);
1301 TypeAttr::get(type));
1305 if (argAttrs.empty())
1307 assert(type.getNumInputs() == argAttrs.size());
1309 builder, state, argAttrs, {},
1310 getArgAttrsAttrName(state.
name), getResAttrsAttrName(state.
name));
1314 auto buildFuncType =
1317 std::string &) {
return builder.
getFunctionType(argTypes, results); };
1321 getFunctionTypeAttrName(
result.name), buildFuncType,
1322 getArgAttrsAttrName(
result.name), getResAttrsAttrName(
result.name));
1327 p, *
this,
false, getFunctionTypeAttrName(),
1328 getArgAttrsAttrName(), getResAttrsAttrName());
1335std::optional<int64_t> GetExtentOp::getConstantDim() {
1336 if (
auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1337 return constSizeOp.getValue().getLimitedValue();
1338 if (
auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1339 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1340 return std::nullopt;
1344 auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1347 std::optional<int64_t> dim = getConstantDim();
1348 if (!dim.has_value())
1350 if (dim.value() >= elements.getNumElements())
1352 return elements.getValues<
Attribute>()[(uint64_t)dim.value()];
1357 auto loc =
result.location;
1359 if (llvm::isa<ShapeType>(
shape.getType())) {
1360 Value dim = ConstSizeOp::create(builder, loc, dimAttr);
1369LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1370 MLIRContext *context, std::optional<Location> location,
1372 inferredReturnTypes.assign({IndexType::get(context)});
1376bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(
TypeRange l,
1390 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1393OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1395 if (adaptor.getShapes().size() < 2) {
1406LogicalResult mlir::shape::MeetOp::inferReturnTypes(
1407 MLIRContext *context, std::optional<Location> location,
1409 if (adaptor.getOperands().empty())
1412 auto isShapeType = [](
Type arg) {
1413 if (llvm::isa<ShapeType>(arg))
1420 for (
auto t : drop_begin(types)) {
1422 if (!llvm::isa<ShapeType, SizeType>(l))
1426 if (llvm::isa<SizeType>(l)) {
1427 if (llvm::isa<SizeType, IndexType>(r))
1431 }
else if (llvm::isa<IndexType>(l)) {
1432 if (llvm::isa<IndexType>(r))
1436 }
else if (llvm::isa<ShapeType>(l)) {
1443 auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
1444 auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
1445 if (ShapedType::isDynamic(rank1))
1447 else if (ShapedType::isDynamic(rank2))
1449 else if (rank1 != rank2)
1455 inferredReturnTypes.assign({
acc});
1460 if (l.size() != 1 || r.size() != 1)
1468 if (!llvm::isa<ShapeType, SizeType>(
lhs))
1471 if (llvm::isa<SizeType>(
lhs))
1472 return llvm::isa<SizeType, IndexType>(
rhs);
1473 if (llvm::isa<ShapeType>(
lhs))
1474 return llvm::isa<ShapeType, TensorType>(
rhs);
1486 auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1509struct RankShapeOfCanonicalizationPattern
1511 using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1513 LogicalResult matchAndRewrite(shape::RankOp op,
1514 PatternRewriter &rewriter)
const override {
1515 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1518 auto rankedTensorType =
1519 llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1520 if (!rankedTensorType)
1522 int64_t rank = rankedTensorType.getRank();
1523 if (llvm::isa<IndexType>(op.getType())) {
1526 }
else if (llvm::isa<shape::SizeType>(op.getType())) {
1538 patterns.add<RankShapeOfCanonicalizationPattern>(context);
1541LogicalResult mlir::shape::RankOp::inferReturnTypes(
1542 MLIRContext *context, std::optional<Location> location,
1544 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1545 inferredReturnTypes.assign({SizeType::get(context)});
1547 inferredReturnTypes.assign({IndexType::get(context)});
1570 for (
auto value : llvm::cast<DenseIntElementsAttr>(
shape))
1576LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1577 MLIRContext *context, std::optional<Location> location,
1578 NumElementsOp::Adaptor adaptor,
1580 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1581 inferredReturnTypes.assign({SizeType::get(context)});
1583 inferredReturnTypes.assign({IndexType::get(context)});
1587bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(
TypeRange l,
1593LogicalResult shape::NumElementsOp::verify() {
1603 if (getLhs() == getRhs())
1608LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1609 MLIRContext *context, std::optional<Location> location,
1611 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1612 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1614 inferredReturnTypes.assign({SizeType::get(context)});
1619 if (l.size() != 1 || r.size() != 1)
1621 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1623 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1634 if (getLhs() == getRhs())
1639LogicalResult mlir::shape::MinOp::inferReturnTypes(
1640 MLIRContext *context, std::optional<Location> location,
1642 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1643 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1645 inferredReturnTypes.assign({SizeType::get(context)});
1650 if (l.size() != 1 || r.size() != 1)
1652 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1654 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1664 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1667 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1670 APInt folded =
lhs.getValue() *
rhs.getValue();
1672 return IntegerAttr::get(indexTy, folded);
1675LogicalResult mlir::shape::MulOp::inferReturnTypes(
1676 MLIRContext *context, std::optional<Location> location,
1678 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1679 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1680 inferredReturnTypes.assign({SizeType::get(context)});
1682 inferredReturnTypes.assign({IndexType::get(context)});
1699struct ShapeOfOpToConstShapeOp :
public OpRewritePattern<shape::ShapeOfOp> {
1700 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1702 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1703 PatternRewriter &rewriter)
const override {
1704 auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1705 if (!type || !type.hasStaticShape())
1707 Location loc = op.getLoc();
1709 ConstShapeOp::create(rewriter, loc,
1712 if (constShape.
getType() != op.getResult().getType())
1713 constShape = tensor::CastOp::create(rewriter, loc,
1714 op.getResult().getType(), constShape);
1731 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1733 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1734 PatternRewriter &rewriter)
const override {
1735 auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1736 if (!tensorReshapeOp)
1738 if (!isa<TensorType>(op.getType()))
1749 Value shape = tensorReshapeOp.getShape();
1751 auto opTensorTy = cast<RankedTensorType>(op.getType());
1752 auto shapeTensorTy = cast<RankedTensorType>(shape.
getType());
1754 if (opTensorTy != shapeTensorTy) {
1755 if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
1757 tensor::CastOp::create(rewriter, op.getLoc(), opTensorTy, shape);
1759 shape = arith::IndexCastOp::create(rewriter, op.getLoc(), opTensorTy,
1778 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1780 LogicalResult matchAndRewrite(tensor::CastOp op,
1781 PatternRewriter &rewriter)
const override {
1782 auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1783 if (!ty || ty.getRank() != 1)
1786 auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1791 auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1792 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1803 patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1804 ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1808LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1809 MLIRContext *context, std::optional<Location> location,
1811 if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1812 inferredReturnTypes.assign({ShapeType::get(context)});
1814 auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1816 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1817 Type indexTy = IndexType::get(context);
1818 Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1819 inferredReturnTypes.assign({extentTensorTy});
1825 if (l.size() != 1 || r.size() != 1)
1833 if (!llvm::isa<ShapeType, ShapedType>(
lhs) ||
1834 !llvm::isa<ShapeType, ShapedType>(
rhs))
1837 if (llvm::isa<ShapeType>(
lhs) || llvm::isa<ShapeType>(
rhs))
1846LogicalResult shape::ShapeOfOp::verify() {
1864 patterns.add<IndexToSizeToIndexCanonicalization>(context);
1868 if (inputs.size() != 1 || outputs.size() != 1)
1870 return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1871 llvm::isa<IndexType>(outputs[0]);
1878LogicalResult shape::YieldOp::verify() {
1879 auto *parentOp = (*this)->getParentOp();
1880 auto results = parentOp->getResults();
1881 auto operands = getOperands();
1883 if (parentOp->getNumResults() != getNumOperands())
1884 return emitOpError() <<
"number of operands does not match number of "
1885 "results of its parent";
1886 for (
auto e : llvm::zip(results, operands))
1888 return emitOpError() <<
"types mismatch between yield op and its parent";
1897LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1899 if (!adaptor.getOperand() || !adaptor.getIndex())
1901 auto shapeVec = llvm::to_vector<6>(
1902 llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<
int64_t>());
1904 auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1908 if (-rank > splitPoint || splitPoint > rank)
1911 splitPoint +=
shape.size();
1912 Builder builder(adaptor.getOperand().getContext());
1922OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1923 if (!adaptor.getInput())
1926 auto shape = llvm::to_vector<6>(
1927 llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<
int64_t>());
1928 auto type = RankedTensorType::get({
static_cast<int64_t>(
shape.size())},
1934 if (inputs.size() != 1 || outputs.size() != 1)
1936 if (
auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1937 if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1938 inputTensor.getRank() != 1)
1940 }
else if (!llvm::isa<ShapeType>(inputs[0])) {
1944 TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1945 return outputTensor && llvm::isa<IndexType>(outputTensor.
getElementType());
1956 result.addOperands(initVals);
1963 if (
auto tensorType = llvm::dyn_cast<TensorType>(
shape.getType()))
1964 elementType = tensorType.getElementType();
1966 elementType = SizeType::get(builder.
getContext());
1969 for (
Value initVal : initVals) {
1970 bodyBlock->
addArgument(initVal.getType(), initVal.getLoc());
1971 result.addTypes(initVal.getType());
1975LogicalResult ReduceOp::verify() {
1980 auto blockArgsCount = getInitVals().size() + 2;
1982 return emitOpError() <<
"ReduceOp body is expected to have "
1983 << blockArgsCount <<
" arguments";
1988 "argument 0 of ReduceOp body is expected to be of IndexType");
1995 if (!llvm::isa<SizeType>(extentTy))
1996 return emitOpError(
"argument 1 of ReduceOp body is expected to be of "
1997 "SizeType if the ReduceOp operates on a ShapeType");
1999 if (!llvm::isa<IndexType>(extentTy))
2001 "argument 1 of ReduceOp body is expected to be of IndexType if the "
2002 "ReduceOp operates on an extent tensor");
2005 for (
const auto &type : llvm::enumerate(getInitVals()))
2007 return emitOpError() <<
"type mismatch between argument "
2009 <<
" of ReduceOp body and initial value "
2017 Type shapeOrExtentTensorType;
2026 if (parser.
resolveOperand(operands.front(), shapeOrExtentTensorType,
2045 p <<
'(' <<
getShape() <<
", " << getInitVals()
2053#define GET_OP_CLASSES
2054#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2056#define GET_TYPEDEF_CLASSES
2057#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool isErrorPropagationPossible(TypeRange operandTypes)
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
static LogicalResult verifySizeOrIndexOp(Operation *op)
static int64_t product(ArrayRef< int64_t > vals)
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
MLIRContext * getContext() const
Return the context this attribute belongs to.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
IntegerAttr getIndexAttr(int64_t value)
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
StringAttr getStringAttr(const Twine &bytes)
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
MLIRContext * getContext() const
NamedAttribute getNamedAttr(StringRef name, Attribute val)
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents a single result from folding an operation.
A trait used to provide symbol table functionalities to a region operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
iterator_range< dialect_attr_iterator > dialect_attr_range
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
operand_type_range getOperandTypes()
result_type_range getResultTypes()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
ValueTypeRange< ValueRange > type_range
Type front()
Return first type in the range.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
bool isExtentTensorType(Type)
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
llvm::function_ref< Fn > function_ref
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.