MLIR 22.0.0git
Shape.cpp
Go to the documentation of this file.
1//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include <utility>
10
12
17#include "mlir/Dialect/Traits.h"
19#include "mlir/IR/Builders.h"
22#include "mlir/IR/Matchers.h"
27#include "llvm/ADT/SetOperations.h"
28#include "llvm/ADT/TypeSwitch.h"
29#include "llvm/Support/raw_ostream.h"
30
31using namespace mlir;
32using namespace mlir::shape;
33
34#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
35
36namespace {
37#include "ShapeCanonicalization.inc"
38} // namespace
39
40RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
41 return RankedTensorType::get({rank}, IndexType::get(ctx));
42}
43
  // Body of the extent-tensor type predicate: a type qualifies iff it is a
  // ranked, rank-1 tensor with `index` elements.
  // NOTE(review): the enclosing function signature (presumably
  // `bool shape::isExtentTensorType(Type type) {`) is missing from this chunk
  // (original line 44) — confirm against upstream.
  auto ranked = llvm::dyn_cast<RankedTensorType>(type);
  return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
}
48
// Extracts a statically known shape from `input`, appending the extents to
// `shapeValues`. Succeeds for a `shape.shape_of` of a ranked operand or for a
// constant extent tensor; fails otherwise.
LogicalResult shape::getShapeVec(Value input,
                                 SmallVectorImpl<int64_t> &shapeValues) {
  // Case 1: input is shape_of(%arg) — read the static shape off the arg type.
  if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
    auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
    if (!type.hasRank())
      return failure();
    llvm::append_range(shapeValues, type.getShape());
    return success();
  }
  // Case 2: input is a constant extent tensor.
  // NOTE(review): the declaration `DenseIntElementsAttr attr;` appears to be
  // missing from this chunk (original line 58) — confirm against upstream.
  if (matchPattern(input, m_Constant(&attr))) {
    llvm::append_range(shapeValues, attr.getValues<int64_t>());
    return success();
  }
  return failure();
}
65
66static bool isErrorPropagationPossible(TypeRange operandTypes) {
67 return llvm::any_of(operandTypes,
68 llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
69}
70
// Verifies that a single-result op returns `size` whenever error propagation
// from its operands is possible.
static LogicalResult verifySizeOrIndexOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  // NOTE(review): the guard (presumably
  // `if (isErrorPropagationPossible(op->getOperandTypes())) {`) is missing
  // from this chunk (original line 74); the stray `}` below closes it.
  if (!llvm::isa<SizeType>(resultTy))
    return op->emitOpError()
           << "if at least one of the operands can hold error values then "
              "the result must be of type `size` to propagate them";
  }
  return success();
}
82
// Verifies that a single-result op returns `shape` whenever error propagation
// from its operands is possible.
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
  assert(op != nullptr && op->getNumResults() == 1);
  Type resultTy = op->getResultTypes().front();
  // NOTE(review): the guard (presumably
  // `if (isErrorPropagationPossible(op->getOperandTypes())) {`) is missing
  // from this chunk (original line 86); the stray `}` below closes it.
  if (!llvm::isa<ShapeType>(resultTy))
    return op->emitOpError()
           << "if at least one of the operands can hold error values then "
              "the result must be of type `shape` to propagate them";
  }
  return success();
}
94
95template <typename... Ty>
96static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
97 return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
98}
99
100template <typename... Ty, typename... ranges>
101static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
102 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
103}
104
105//===----------------------------------------------------------------------===//
106// InlinerInterface
107//===----------------------------------------------------------------------===//
108
109namespace {
/// This class defines the interface for inlining shape dialect ops. All shape
/// ops and regions are unconditionally legal to inline.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  // NOTE(review): the inherited-constructor line
  // `using DialectInlinerInterface::DialectInlinerInterface;` appears to be
  // missing from this chunk (original line 112) — confirm against upstream.

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       IRMapping &) const final {
    return true;
  }
};
129} // namespace
130
void ShapeDialect::initialize() {
  // Register all ODS-generated shape dialect operations.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  // Register all ODS-generated shape dialect types.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
  // Promise the bufferization interface implementations; they are registered
  // lazily by the bufferization dialect extension.
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
                            AssumingYieldOp>();
}
148
149Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
150 Attribute value, Type type,
151 Location loc) {
152 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
153 return ub::PoisonOp::create(builder, loc, type, poison);
154
155 if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
156 return ConstShapeOp::create(builder, loc, type,
157 llvm::cast<DenseIntElementsAttr>(value));
158 if (llvm::isa<SizeType>(type))
159 return ConstSizeOp::create(builder, loc, type,
160 llvm::cast<IntegerAttr>(value));
161 if (llvm::isa<WitnessType>(type))
162 return ConstWitnessOp::create(builder, loc, type,
163 llvm::cast<BoolAttr>(value));
164
165 return arith::ConstantOp::materialize(builder, value, type, loc);
166}
167
// Verifies the discardable `shape.lib` attribute: it may only appear on a
// SymbolTable op and must be a SymbolRefAttr (or array thereof) referring to
// shape function libraries with mutually unique op-to-shape-function mappings.
LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
                                                     NamedAttribute attribute) {
  // Verify shape.lib attribute.
  if (attribute.getName() == "shape.lib") {
    if (!op->hasTrait<OpTrait::SymbolTable>())
      return op->emitError(
          "shape.lib attribute may only be on op implementing SymbolTable");

    // Single symbol reference: must resolve to a FunctionLibraryOp.
    if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
      auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
      if (!symbol)
        return op->emitError("shape function library ")
               << symbolRef << " not found";
      return isa<shape::FunctionLibraryOp>(symbol)
                 ? success()
                 : op->emitError()
                       << symbolRef << " required to be shape function library";
    }

    if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
      // Verify all entries are function libraries and mappings in libraries
      // refer to unique ops.
      // NOTE(review): the declaration (presumably `DenseSet<StringRef> key;`)
      // is missing from this chunk (original line 190) — confirm upstream.
      for (auto it : arr) {
        if (!llvm::isa<SymbolRefAttr>(it))
          return op->emitError(
              "only SymbolRefAttr allowed in shape.lib attribute array");

        auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
            SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
        if (!shapeFnLib)
          return op->emitError()
                 << it << " does not refer to FunctionLibraryOp";
        // Each op name may map to a shape function at most once across all
        // libraries in the array.
        for (auto mapping : shapeFnLib.getMapping()) {
          if (!key.insert(mapping.getName()).second) {
            return op->emitError("only one op to shape mapping allowed, found "
                                 "multiple for `")
                   << mapping.getName() << "`";
          }
        }
      }
      return success();
    }

    return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
                         "allowed as shape.lib attribute");
  }
  return success();
}
217
218//===----------------------------------------------------------------------===//
219// AnyOp
220//===----------------------------------------------------------------------===//
221
222// TODO: Canonicalization should be implemented for shapes that can be
223// determined through mixtures of the known dimensions of the inputs.
224OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
225 // Only the last operand is checked because AnyOp is commutative.
226 if (adaptor.getInputs().back())
227 return adaptor.getInputs().back();
228
229 return nullptr;
230}
231
232//===----------------------------------------------------------------------===//
233// AssumingOp
234//===----------------------------------------------------------------------===//
235
// Custom parser: `shape.assuming <witness> [-> (result-types)] region
// [attr-dict]`. The terminator may be elided and is re-added here.
ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
  result.regions.reserve(1);
  Region *doRegion = result.addRegion();

  auto &builder = parser.getBuilder();
  // NOTE(review): the declaration (presumably
  // `OpAsmParser::UnresolvedOperand cond;`) is missing from this chunk
  // (original line 241) — confirm against upstream.
  if (parser.parseOperand(cond) ||
      parser.resolveOperand(cond, builder.getType<WitnessType>(),
                            result.operands))
    return failure();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the region and add a terminator if elided.
  if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}
261
// Custom printer; the block terminator is printed only when the op yields
// results (matching the parser, which re-adds an elided terminator).
void AssumingOp::print(OpAsmPrinter &p) {
  bool yieldsResults = !getResults().empty();

  p << " " << getWitness();
  if (yieldsResults)
    p << " -> (" << getResultTypes() << ")";
  p << ' ';
  p.printRegion(getDoRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/yieldsResults);
  p.printOptionalAttrDict((*this)->getAttrs());
}
274
275namespace {
276// Removes AssumingOp with a passing witness and inlines the region.
277struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
278 using OpRewritePattern<AssumingOp>::OpRewritePattern;
279
280 LogicalResult matchAndRewrite(AssumingOp op,
281 PatternRewriter &rewriter) const override {
282 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
283 if (!witness || !witness.getPassingAttr())
284 return failure();
285
286 AssumingOp::inlineRegionIntoParent(op, rewriter);
287 return success();
288 }
289};
290
// Shrinks an assuming op by dropping results that have no uses, rebuilding the
// op with a correspondingly reduced yield.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    for (auto [opResult, yieldOperand] :
         llvm::zip(op.getResults(), yieldOp.getOperands())) {
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = AssumingOp::create(
        rewriter, op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      // Unused results map to null (they have no uses to rewrite); used ones
      // consume the new op's results in order.
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
335} // namespace
336
337void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
338 MLIRContext *context) {
339 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
340}
341
342// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
void AssumingOp::getSuccessorRegions(
    // NOTE(review): the parameter list (presumably `RegionBranchPoint point,
    // SmallVectorImpl<RegionSuccessor> &regions) {`) is missing from this
    // chunk (original line 344) — confirm against upstream.
  // AssumingOp has unconditional control flow into the region and back to the
  // parent, so return the correct RegionSuccessor purely based on the index
  // being None or 0.
  if (!point.isParent()) {
    // Control leaves the body and flows back to the op's results.
    regions.push_back(RegionSuccessor(getOperation(), getResults()));
    return;
  }

  // From the parent, control enters the single do-region.
  regions.push_back(RegionSuccessor(&getDoRegion()));
}
355
// Inlines the assuming op's body into the parent block, replacing the op's
// results with the operands of its yield. Used when the witness is known
// passing. The splitting/merging order below is load-bearing.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Split the surrounding block so the body can be spliced in between.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
375
// Builder that populates the body via a callback; result types are inferred
// from the values the callback returns (which become the yield operands).
void AssumingOp::build(
    OpBuilder &builder, OperationState &result, Value witness,
    // NOTE(review): the `bodyBuilder` function_ref parameter line appears to
    // be missing from this chunk (original line 378) — confirm upstream.
  OpBuilder::InsertionGuard g(builder);

  result.addOperands(witness);
  Region *bodyRegion = result.addRegion();
  builder.createBlock(bodyRegion);

  // Build body.
  SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
  AssumingYieldOp::create(builder, result.location, yieldValues);

  // Result types mirror the yielded value types.
  SmallVector<Type, 2> assumingTypes;
  for (Value v : yieldValues)
    assumingTypes.push_back(v.getType());
  result.addTypes(assumingTypes);
}
394
395//===----------------------------------------------------------------------===//
396// AddOp
397//===----------------------------------------------------------------------===//
398
399LogicalResult mlir::shape::AddOp::inferReturnTypes(
400 MLIRContext *context, std::optional<Location> location,
401 AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
402 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
403 llvm::isa<SizeType>(adaptor.getRhs().getType()))
404 inferredReturnTypes.assign({SizeType::get(context)});
405 else
406 inferredReturnTypes.assign({IndexType::get(context)});
407 return success();
408}
409
bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  // SizeType is compatible with IndexType.
  // NOTE(review): the return statement (presumably
  // `return eachHasOnlyOneOfTypes<SizeType, IndexType>(l, r);`) is missing
  // from this chunk (original line 412) — confirm against upstream.
}
414
OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
  // add(x, 0) -> x
  if (matchPattern(getRhs(), m_Zero()))
    return getLhs();

  // Otherwise constant-fold when both operands are constant integers.
  // NOTE(review): the call head (presumably
  // `return constFoldBinaryOp<IntegerAttr>(`) is missing from this chunk
  // (original line 420) — confirm against upstream.
      adaptor.getOperands(),
      [](APInt a, const APInt &b) { return std::move(a) + b; });
}
424
425LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
426
427//===----------------------------------------------------------------------===//
428// AssumingAllOp
429//===----------------------------------------------------------------------===//
430
431namespace {
432
433// Merge multiple `shape.assuming_all` operations together.
434//
435// %0 = shape.assuming_all %w0, %w1
436// %1 = shape.assuming_all %w2, %0
437//
438// to:
439//
// %0 = shape.assuming_all %w2, %w0, %w1
441struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
442 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
443
444 LogicalResult matchAndRewrite(AssumingAllOp op,
445 PatternRewriter &rewriter) const override {
446 SmallVector<Value> operands;
447
448 for (Value operand : op.getInputs()) {
449 if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
450 operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
451 else
452 operands.push_back(operand);
453 }
454
455 // We didn't find any other `assuming_all` ops to merge with.
456 if (operands.size() == op.getNumOperands())
457 return failure();
458
459 // Replace with a new `assuming_all` operation with merged constraints.
460 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
461 return success();
462 }
463};
464
465// Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
466// are subsumed by others.
467//
468// %0 = shape.cstr_broadcastable %shape0, %shape1
469// %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
470//
471// %2 = shape.cstr_broadcastable %shape3, %shape4
472// %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
473//
474// %4 = shape.assuming_all %0, %1, %2, %3
475//
476// to:
477//
478// %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
479// %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
480// %2 = shape.assuming_all %0, %1
481//
482// In this example if shapes [0, 1, 2] are broadcastable, then it means that
483// shapes [0, 1] are broadcastable too, and can be removed from the list of
484// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
485// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
// Removes `cstr_broadcastable` witnesses whose shape set is subsumed by a
// larger `cstr_broadcastable` witness of the same `assuming_all` (see the
// comment block above for the rationale).
struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all `CstrBroadcastableOp` operands first.
    // NOTE(review): the declaration (presumably
    // `SetVector<CstrBroadcastableOp> operands;`) is missing from this chunk
    // (original line 492) — confirm against upstream.
    for (Value operand : op.getInputs()) {
      // TODO: Apply this optimization if some of the witnesses are not
      // produced by the `cstr_broadcastable`.
      auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
      if (!broadcastable)
        return failure();

      operands.insert(broadcastable);
    }

    // Skip trivial `assuming_all` operations.
    if (operands.size() <= 1)
      return failure();

    // Collect shapes checked by `cstr_broadcastable` operands.
    SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
    for (auto cstr : operands) {
      DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
      shapes.emplace_back(cstr, std::move(shapesSet));
    }

    // Sort by the number of shape operands (larger to smaller).
    llvm::sort(shapes, [](auto a, auto b) {
      return a.first.getNumOperands() > b.first.getNumOperands();
    });

    // We start from the `cst_broadcastable` operations with largest number of
    // shape operands, and remove redundant `cst_broadcastable` operations. We
    // do this until we find a set of `cst_broadcastable` operations with
    // non-overlapping constraints.
    SmallVector<CstrBroadcastableOp> markedForErase;

    for (unsigned i = 0; i < shapes.size(); ++i) {
      // A later entry is redundant if its shape set is contained in ours.
      auto isSubset = [&](auto pair) {
        return llvm::set_is_subset(pair.second, shapes[i].second);
      };

      // Keep redundant `cstr_broadcastable` operations to be erased.
      auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
      for (auto *it0 = it; it0 < shapes.end(); ++it0)
        markedForErase.push_back(it0->first);
      shapes.erase(it, shapes.end());
    }

    // We didn't find any operands that could be removed.
    if (markedForErase.empty())
      return failure();

    // Collect non-overlapping `cst_broadcastable` constraints.
    SmallVector<Value> uniqueConstraints;
    for (auto &shape : shapes)
      uniqueConstraints.push_back(shape.first.getResult());

    // Replace with a new `assuming_all` operation ...
    rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);

    // ... and maybe erase `cstr_broadcastable` ops without uses.
    for (auto &op : markedForErase)
      if (op->use_empty())
        rewriter.eraseOp(op);

    return success();
  }
};
557
558struct AssumingAllToCstrEqCanonicalization
559 : public OpRewritePattern<AssumingAllOp> {
560 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
561
562 LogicalResult matchAndRewrite(AssumingAllOp op,
563 PatternRewriter &rewriter) const override {
564 SmallVector<Value, 8> shapes;
565 for (Value w : op.getInputs()) {
566 auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
567 if (!cstrEqOp)
568 return failure();
569 bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
570 return llvm::is_contained(shapes, s);
571 });
572 if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
573 return failure();
574 shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
575 }
576 rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
577 return success();
578 }
579};
580
581template <typename OpTy>
582struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
583 using OpRewritePattern<OpTy>::OpRewritePattern;
584
585 LogicalResult matchAndRewrite(OpTy op,
586 PatternRewriter &rewriter) const override {
587 // Find unique operands.
588 SetVector<Value> unique(op.operand_begin(), op.operand_end());
589
590 // Reduce op to equivalent with unique operands.
591 if (unique.size() < op.getNumOperands()) {
592 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
593 unique.takeVector(), op->getAttrs());
594 return success();
595 }
596
597 return failure();
598 }
599};
600} // namespace
601
void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // NOTE(review): the receiver of this chained call (presumably a line
  // containing just `patterns`) is missing from this chunk (original line
  // 604) — confirm against upstream.
      .add<MergeAssumingAllOps, AssumingAllOneOp,
           AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
           RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
}
609
610OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
611 // Iterate in reverse to first handle all constant operands. They are
612 // guaranteed to be the tail of the inputs because this is commutative.
613 for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
614 Attribute a = adaptor.getInputs()[idx];
615 // Cannot fold if any inputs are not constant;
616 if (!a)
617 return nullptr;
618
619 // We do not need to keep statically known values after handling them in
620 // this method.
621 getOperation()->eraseOperand(idx);
622
623 // Always false if any input is statically known false
624 if (!llvm::cast<BoolAttr>(a).getValue())
625 return a;
626 }
627 // If this is reached, all inputs were statically known passing.
628 return BoolAttr::get(getContext(), true);
629}
630
631LogicalResult AssumingAllOp::verify() {
632 // Ensure that AssumingAllOp contains at least one operand
633 if (getNumOperands() == 0)
634 return emitOpError("no operands specified");
635
636 return success();
637}
638
639//===----------------------------------------------------------------------===//
640// BroadcastOp
641//===----------------------------------------------------------------------===//
642
// Folds a broadcast of constant shapes to the constant broadcasted shape, and
// a single-operand broadcast to its operand when no cast is needed.
OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
  if (getShapes().size() == 1) {
    // Otherwise, we need a cast which would be a canonicalization, not folding.
    if (getShapes().front().getType() != getType())
      return nullptr;
    return getShapes().front();
  }

  // All operands must be constant for folding.
  if (!adaptor.getShapes().front())
    return nullptr;

  SmallVector<int64_t, 6> resultShape(
      llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
          .getValues<int64_t>());

  // Broadcast the accumulated shape against each remaining constant operand.
  for (auto next : adaptor.getShapes().drop_front()) {
    if (!next)
      return nullptr;
    auto nextShape = llvm::to_vector<6>(
        llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());

    // NOTE(review): the declaration (presumably
    // `SmallVector<int64_t, 6> tmpShape;`) is missing from this chunk
    // (original line 664) — confirm against upstream.
    // If the shapes are not compatible, we can't fold it.
    // TODO: Fold to an "error".
    if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
      return nullptr;

    resultShape.clear();
    std::copy(tmpShape.begin(), tmpShape.end(),
              std::back_inserter(resultShape));
  }

  Builder builder(getContext());
  return builder.getIndexTensorAttr(resultShape);
}
678
679LogicalResult BroadcastOp::verify() {
680 return verifyShapeOrExtentTensorOp(*this);
681}
682
683namespace {
684template <typename OpTy>
685struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
686 using OpRewritePattern<OpTy>::OpRewritePattern;
687
688 LogicalResult matchAndRewrite(OpTy op,
689 PatternRewriter &rewriter) const override {
690 auto isPotentiallyNonEmptyShape = [](Value shape) {
691 if (auto extentTensorTy =
692 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
693 if (extentTensorTy.getDimSize(0) == 0)
694 return false;
695 }
696 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
697 if (constShape.getShape().empty())
698 return false;
699 }
700 return true;
701 };
702 auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
703 isPotentiallyNonEmptyShape);
704
705 // Replace the op with empty shape constant if all operants are reduced to
706 // be empty.
707 if (newOperands.empty()) {
708 rewriter.replaceOpWithNewOp<ConstShapeOp>(
709 op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
710 return success();
711 }
712
713 // Reduce op to equivalent without empty shape operands.
714 if (newOperands.size() < op.getNumOperands()) {
715 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
716 op->getAttrs());
717 return success();
718 }
719
720 return failure();
721 }
722};
723
// Forwards the sole shape operand of a single-operand broadcast, inserting a
// type cast when the operand and result types differ.
struct BroadcastForwardSingleOperandPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 1)
      return failure();
    Value replacement = op.getShapes().front();

    // Insert cast if needed.
    if (replacement.getType() != op.getType()) {
      auto loc = op.getLoc();
      if (llvm::isa<ShapeType>(op.getType())) {
        replacement = FromExtentTensorOp::create(rewriter, loc, replacement);
      } else {
        assert(!llvm::isa<ShapeType>(op.getType()) &&
               !llvm::isa<ShapeType>(replacement.getType()) &&
               "expect extent tensor cast");
        // NOTE(review): the assignment target (presumably `replacement =`) is
        // missing from this chunk (original line 743) — confirm upstream.
        tensor::CastOp::create(rewriter, loc, op.getType(), replacement);
      }
    }

    rewriter.replaceOp(op, replacement);
    return success();
  }
};
752
// Partially folds a broadcast by combining all constant-shape operands into a
// single trailing constant operand, keeping non-constant operands intact.
struct BroadcastFoldConstantOperandsPattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t, 8> foldedConstantShape;
    SmallVector<Value, 8> newShapeOperands;
    for (Value shape : op.getShapes()) {
      if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
        SmallVector<int64_t, 8> newFoldedConstantShape;
        // NOTE(review): the call head (presumably
        // `if (OpTrait::util::getBroadcastedShape(`) is missing from this
        // chunk (original line 764) — confirm against upstream.
            foldedConstantShape,
            llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
            newFoldedConstantShape)) {
          foldedConstantShape = newFoldedConstantShape;
          continue;
        }
      }
      newShapeOperands.push_back(shape);
    }

    // Need at least two constant operands to fold anything.
    if (op.getNumOperands() - newShapeOperands.size() < 2)
      return failure();

    auto foldedConstantOperandsTy = RankedTensorType::get(
        {static_cast<int64_t>(foldedConstantShape.size())},
        rewriter.getIndexType());
    newShapeOperands.push_back(
        ConstShapeOp::create(rewriter, op.getLoc(), foldedConstantOperandsTy,
                             rewriter.getIndexTensorAttr(foldedConstantShape)));
    rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
                                             newShapeOperands);
    return success();
  }
};
790
791template <typename OpTy>
792struct CanonicalizeCastExtentTensorOperandsPattern
793 : public OpRewritePattern<OpTy> {
794 using OpRewritePattern<OpTy>::OpRewritePattern;
795
796 LogicalResult matchAndRewrite(OpTy op,
797 PatternRewriter &rewriter) const override {
798 // Canonicalize operands.
799 bool anyChange = false;
800 auto canonicalizeOperand = [&](Value operand) -> Value {
801 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
802 // Only eliminate the cast if it holds no shape information.
803 bool isInformationLoosingCast =
804 llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
805 if (isInformationLoosingCast) {
806 anyChange = true;
807 return castOp.getSource();
808 }
809 }
810 return operand;
811 };
812 auto newOperands = llvm::to_vector<8>(
813 llvm::map_range(op.getOperands(), canonicalizeOperand));
814
815 // Rewrite op if any change required.
816 if (!anyChange)
817 return failure();
818 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
819 return success();
820 }
821};
822
// Concretizes a broadcast's dynamically sized extent-tensor result type when
// the maximum rank can be inferred from statically sized operands, inserting a
// cast back to the original type.
struct BroadcastConcretizeResultTypePattern
    : public OpRewritePattern<BroadcastOp> {
  using OpRewritePattern<BroadcastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Only concretize dynamic extent tensor result types.
    auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
    if (!resultTy || !resultTy.isDynamicDim(0))
      return failure();

    // Infer resulting shape rank if possible.
    int64_t maxRank = 0;
    for (Value shape : op.getShapes()) {
      if (auto extentTensorTy =
              llvm::dyn_cast<RankedTensorType>(shape.getType())) {
        // Cannot infer resulting shape rank if any operand is dynamically
        // ranked.
        if (extentTensorTy.isDynamicDim(0))
          return failure();
        maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
      }
    }

    // NOTE(review): the result-type argument (presumably
    // `getExtentTensorType(getContext(), maxRank),`) is missing from this
    // chunk (original line 848) — confirm against upstream.
    auto newOp = BroadcastOp::create(rewriter, op.getLoc(),
                                     op.getShapes());
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
    return success();
  }
};
854} // namespace
855
856void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
857 MLIRContext *context) {
858 patterns.add<BroadcastConcretizeResultTypePattern,
859 BroadcastFoldConstantOperandsPattern,
860 BroadcastForwardSingleOperandPattern,
861 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
862 RemoveDuplicateOperandsPattern<BroadcastOp>,
863 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
864}
865
866//===----------------------------------------------------------------------===//
867// ConcatOp
868//===----------------------------------------------------------------------===//
869
870OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
871 if (!adaptor.getLhs() || !adaptor.getRhs())
872 return nullptr;
873 auto lhsShape = llvm::to_vector<6>(
874 llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
875 auto rhsShape = llvm::to_vector<6>(
876 llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
877 SmallVector<int64_t, 6> resultShape;
878 resultShape.append(lhsShape.begin(), lhsShape.end());
879 resultShape.append(rhsShape.begin(), rhsShape.end());
880 Builder builder(getContext());
881 return builder.getIndexTensorAttr(resultShape);
882}
883
884//===----------------------------------------------------------------------===//
885// ConstShapeOp
886//===----------------------------------------------------------------------===//
887
// Custom printer: ` [attr-dict] [e0, e1, ...] : type` with the `shape`
// attribute elided (it is printed as the bracketed extent list instead).
void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}
896
// Custom parser matching the printer above; the bracketed extent list is
// parsed via ArrayAttr and re-stored as the dense `shape` attribute.
ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  // We piggy-back on ArrayAttr parsing, though we don't internally store the
  // shape as an ArrayAttr.
  // TODO: Implement custom parser and maybe make syntax a bit more concise.
  Attribute extentsRaw;
  NamedAttrList dummy;
  if (parser.parseAttribute(extentsRaw, "dummy", dummy))
    return failure();
  auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
  if (!extentsArray)
    return failure();
  // NOTE(review): the declaration (presumably `SmallVector<int64_t, 6> ints;`)
  // is missing from this chunk (original line 910) — confirm upstream.
  for (Attribute extent : extentsArray) {
    // Every list element must be an integer extent.
    IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
    if (!attr)
      return failure();
    ints.push_back(attr.getInt());
  }
  Builder &builder = parser.getBuilder();
  result.addAttribute("shape", builder.getIndexTensorAttr(ints));
  Type resultTy;
  if (parser.parseColonType(resultTy))
    return failure();
  result.types.push_back(resultTy);
  return success();
}
925
// Folding a const_shape simply yields its constant "shape" attribute.
OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }

void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                               MLIRContext *context) {
  // Absorb a tensor cast of a const_shape into the const_shape itself.
  patterns.add<TensorCastConstShape>(context);
}

// The result type is a statically-shaped 1-D index tensor whose length is the
// number of extents stored in the "shape" property.
LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  Builder b(context);
  const Properties prop = adaptor.getProperties();
  inferredReturnTypes.assign({RankedTensorType::get(
      {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
  return success();
}

// Two single-result type lists are compatible if either side is the opaque
// !shape.shape type, or if the types are identical.
bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
                                                        TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;

  Type lhs = l.front();
  Type rhs = r.front();

  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
    // Shape type is compatible with all other valid return types.
    return true;
  return lhs == rhs;
}
956
957//===----------------------------------------------------------------------===//
958// CstrBroadcastableOp
959//===----------------------------------------------------------------------===//
960
961void CstrBroadcastableOp::getCanonicalizationPatterns(
963 // Canonicalization patterns have overlap with the considerations during
964 // folding in case additional shape information is inferred at some point that
965 // does not result in folding.
966 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
967 CstrBroadcastableEqOps,
968 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
969 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
970}
971
972// Return true if there is exactly one attribute not representing a scalar
973// broadcast.
975 bool nonScalarSeen = false;
976 for (Attribute a : attributes) {
977 if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
978 if (nonScalarSeen)
979 return false;
980 nonScalarSeen = true;
981 }
982 }
983 return true;
984}
985
986OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
987 // No broadcasting is needed if all operands but one are scalar.
988 if (hasAtMostSingleNonScalar(adaptor.getShapes()))
989 return BoolAttr::get(getContext(), true);
990
991 if ([&] {
993 for (const auto &operand : adaptor.getShapes()) {
994 if (!operand)
995 return false;
996 extents.push_back(llvm::to_vector<6>(
997 llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
998 }
1000 }())
1001 return BoolAttr::get(getContext(), true);
1002
1003 // Lastly, see if folding can be completed based on what constraints are known
1004 // on the input shapes.
1005 if ([&] {
1007 for (auto shapeValue : getShapes()) {
1008 extents.emplace_back();
1009 if (failed(getShapeVec(shapeValue, extents.back())))
1010 return false;
1011 }
1013 }())
1014 return BoolAttr::get(getContext(), true);
1015
1016 // Because a failing witness result here represents an eventual assertion
1017 // failure, we do not replace it with a constant witness.
1018 return nullptr;
1019}
1020
LogicalResult CstrBroadcastableOp::verify() {
  // Ensure that CstrBroadcastableOp contains at least two operands
  if (getNumOperands() < 2)
    return emitOpError("required at least 2 input shapes");
  return success();
}

//===----------------------------------------------------------------------===//
// CstrEqOp
//===----------------------------------------------------------------------===//

void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                           MLIRContext *context) {
  // If inputs are equal, return passing witness
  patterns.add<CstrEqEqOps>(context);
}

// Folds to a passing witness when all operand shapes are constant and
// identical. Note: all_of is vacuously true on an empty operand list.
OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
  if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
        return a && a == adaptor.getShapes().front();
      }))
    return BoolAttr::get(getContext(), true);

  // Because a failing witness result here represents an eventual assertion
  // failure, we do not try to replace it with a constant witness. Similarly, we
  // cannot if there are any non-const inputs.
  return nullptr;
}
1049
1050//===----------------------------------------------------------------------===//
1051// ConstSizeOp
1052//===----------------------------------------------------------------------===//
1053
// Convenience builder: wraps a raw int64_t into an index attribute.
void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
                        int64_t value) {
  build(builder, result, builder.getIndexAttr(value));
}

// A constant size folds to its value attribute.
OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }

// Suggests SSA names like %c3 for `shape.const_size 3`.
void ConstSizeOp::getAsmResultNames(
    llvm::function_ref<void(Value, StringRef)> setNameFn) {
  SmallString<4> buffer;
  llvm::raw_svector_ostream os(buffer);
  os << "c" << getValue();
  setNameFn(getResult(), os.str());
}

//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//

OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }

//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//

// shape.cstr_require folds to its predicate operand's constant value (if any).
OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
  return adaptor.getPred();
}
1082
1083//===----------------------------------------------------------------------===//
1084// DimOp
1085//===----------------------------------------------------------------------===//
1086
// Returns the dimension index if it is produced by a constant op
// (shape.const_size or arith.constant); std::nullopt otherwise.
std::optional<int64_t> DimOp::getConstantIndex() {
  if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
    return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  return std::nullopt;
}

// Folds shape.dim to a constant when the operand is a ranked shaped value,
// the index is a known in-range constant, and that dimension is static.
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  Type valType = getValue().getType();
  auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
  if (!valShapedType || !valShapedType.hasRank())
    return nullptr;
  std::optional<int64_t> index = getConstantIndex();
  if (!index.has_value())
    return nullptr;
  if (index.value() < 0 || index.value() >= valShapedType.getRank())
    return nullptr;
  auto extent = valShapedType.getDimSize(*index);
  if (ShapedType::isDynamic(extent))
    return nullptr;
  return IntegerAttr::get(IndexType::get(getContext()), extent);
}

// The result type mirrors the index operand's type (index or !shape.size).
LogicalResult mlir::shape::DimOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({adaptor.getIndex().getType()});
  return success();
}
1117
1118bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1120}
1121
1122//===----------------------------------------------------------------------===//
1123// DivOp
1124//===----------------------------------------------------------------------===//
1125
// Folds shape.div with two constant operands using floor semantics.
// Division by zero does not fold.
OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
  if (!lhs)
    return nullptr;
  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
  if (!rhs || rhs.getValue().isZero())
    return nullptr;

  // Division in APInt does not follow floor(lhs, rhs) when the result is
  // negative. Rather, APInt rounds toward zero.
  APInt quotient, remainder;
  APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
  if (quotient.isNegative() && !remainder.isZero()) {
    quotient -= 1;
  }

  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, quotient);
}

// Result is !shape.size if either operand is (error-carrying), else index.
LogicalResult mlir::shape::DivOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1156
1157bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1158 // SizeType is compatible with IndexType.
1160}
1161
1162LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1163
1164//===----------------------------------------------------------------------===//
1165// ShapeEqOp
1166//===----------------------------------------------------------------------===//
1167
1168OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
1169 bool allSame = true;
1170 if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1171 return {};
1172 for (Attribute operand : adaptor.getShapes().drop_front()) {
1173 if (!operand)
1174 return {};
1175 allSame = allSame && operand == adaptor.getShapes().front();
1176 }
1177 return BoolAttr::get(getContext(), allSame);
1178}
1179
1180//===----------------------------------------------------------------------===//
1181// IndexToSizeOp
1182//===----------------------------------------------------------------------===//
1183
OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
  // Constant values of both types, `shape.size` and `index`, are represented as
  // `IntegerAttr`s which makes constant folding simple.
  if (Attribute arg = adaptor.getArg())
    return arg;
  return {};
}

void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  // Fold size_to_index(index_to_size(x)) round-trips back to x.
  patterns.add<SizeToIndexToSizeCanonicalization>(context);
}
1196
1197//===----------------------------------------------------------------------===//
1198// FromExtentsOp
1199//===----------------------------------------------------------------------===//
1200
1201OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1202 if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
1203 return nullptr;
1205 for (auto attr : adaptor.getExtents())
1206 extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1207 Builder builder(getContext());
1208 return builder.getIndexTensorAttr(extents);
1209}
1210
1211//===----------------------------------------------------------------------===//
1212// FunctionLibraryOp
1213//===----------------------------------------------------------------------===//
1214
1215void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1216 StringRef name) {
1217 result.attributes.push_back(builder.getNamedAttr(
1219}
1220
1221FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1222 auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1223 getMapping().get(op->getName().getIdentifier()));
1224 if (!attr)
1225 return nullptr;
1226 return lookupSymbol<FuncOp>(attr);
1227}
1228
1229ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1231 // Parse the op name.
1232 StringAttr nameAttr;
1234 result.attributes))
1235 return failure();
1236
1237 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1238 return failure();
1239
1240 auto *bodyRegion = result.addRegion();
1241 if (parser.parseRegion(*bodyRegion))
1242 return failure();
1243
1244 if (parser.parseKeyword("mapping"))
1245 return failure();
1246
1247 DictionaryAttr mappingAttr;
1248 if (parser.parseAttribute(mappingAttr,
1249 parser.getBuilder().getType<NoneType>(), "mapping",
1250 result.attributes))
1251 return failure();
1252 return success();
1253}
1254
1255void FunctionLibraryOp::print(OpAsmPrinter &p) {
1256 p << ' ';
1257 p.printSymbolName(getName());
1259 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1260 p << ' ';
1261 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1262 /*printBlockTerminators=*/false);
1263 p << " mapping ";
1264 p.printAttributeWithoutType(getMappingAttr());
1265}
1266
1267//===----------------------------------------------------------------------===//
1268// FuncOp
1269//===----------------------------------------------------------------------===//
1270
1271FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1273 OpBuilder builder(location->getContext());
1274 OperationState state(location, getOperationName());
1275 FuncOp::build(builder, state, name, type, attrs);
1276 return cast<FuncOp>(Operation::create(state));
1277}
1278FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1280 SmallVector<NamedAttribute, 8> attrRef(attrs);
1281 return create(location, name, type, llvm::ArrayRef(attrRef));
1282}
1283FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1285 ArrayRef<DictionaryAttr> argAttrs) {
1286 FuncOp func = create(location, name, type, attrs);
1287 func.setAllArgAttrs(argAttrs);
1288 return func;
1289}
1290
1291void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
1292 FunctionType type, ArrayRef<NamedAttribute> attrs,
1293 ArrayRef<DictionaryAttr> argAttrs) {
1294 state.addAttribute(FuncOp::getSymNameAttrName(state.name),
1295 builder.getStringAttr(name));
1296 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
1297 TypeAttr::get(type));
1298 state.attributes.append(attrs.begin(), attrs.end());
1299 state.addRegion();
1300
1301 if (argAttrs.empty())
1302 return;
1303 assert(type.getNumInputs() == argAttrs.size());
1305 builder, state, argAttrs, /*resultAttrs=*/{},
1306 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
1307}
1308
1309ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1310 auto buildFuncType =
1311 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
1313 std::string &) { return builder.getFunctionType(argTypes, results); };
1314
1316 parser, result, /*allowVariadic=*/false,
1317 getFunctionTypeAttrName(result.name), buildFuncType,
1318 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1319}
1320
1321void FuncOp::print(OpAsmPrinter &p) {
1323 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
1324 getArgAttrsAttrName(), getResAttrsAttrName());
1325}
1326
1327//===----------------------------------------------------------------------===//
1328// GetExtentOp
1329//===----------------------------------------------------------------------===//
1330
// Returns the dim operand as a constant if it is produced by a constant op;
// std::nullopt otherwise.
std::optional<int64_t> GetExtentOp::getConstantDim() {
  if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
    return constSizeOp.getValue().getLimitedValue();
  if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
    return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
  return std::nullopt;
}

// Folds when the shape is a constant extent tensor and the dim is a known
// in-bounds constant: returns that extent as an attribute.
OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
  auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
  if (!elements)
    return nullptr;
  std::optional<int64_t> dim = getConstantDim();
  if (!dim.has_value())
    return nullptr;
  if (dim.value() >= elements.getNumElements())
    return nullptr;
  return elements.getValues<Attribute>()[(uint64_t)dim.value()];
}

// Convenience builder: materializes the dim as a const_size (for
// !shape.shape operands) or an arith.constant index (for extent tensors),
// choosing the matching result type.
void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
                        int64_t dim) {
  auto loc = result.location;
  auto dimAttr = builder.getIndexAttr(dim);
  if (llvm::isa<ShapeType>(shape.getType())) {
    Value dim = ConstSizeOp::create(builder, loc, dimAttr);
    build(builder, result, builder.getType<SizeType>(), shape, dim);
  } else {
    Value dim = arith::ConstantOp::create(builder, loc, builder.getIndexType(),
                                          dimAttr);
    build(builder, result, builder.getIndexType(), shape, dim);
  }
}

LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1371
1372bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1373 TypeRange r) {
1374 // SizeType is compatible with IndexType.
1376}
1377
1378LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1379
1380//===----------------------------------------------------------------------===//
1381// IsBroadcastableOp
1382//===----------------------------------------------------------------------===//
1383
void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                    MLIRContext *context) {
  // Duplicate operands contribute nothing to the broadcastability check.
  patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
}

OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
  // Can always broadcast fewer than two shapes.
  if (adaptor.getShapes().size() < 2) {
    return BoolAttr::get(getContext(), true);
  }

  return nullptr;
}
1397
1398//===----------------------------------------------------------------------===//
1399// MeetOp
1400//===----------------------------------------------------------------------===//
1401
// Infers the meet of the operand types by folding them pairwise left to
// right: sizes meet sizes/indices, shapes meet shapes/extent tensors, and
// two static extent tensors must agree on their length.
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getOperands().empty())
    return failure();

  auto isShapeType = [](Type arg) {
    if (llvm::isa<ShapeType>(arg))
      return true;
    return isExtentTensorType(arg);
  };

  ValueRange::type_range types = adaptor.getOperands().getTypes();
  Type acc = types.front();
  for (auto t : drop_begin(types)) {
    Type l = acc, r = t;
    // Normalize so the "stronger" (error-carrying) type is on the left.
    if (!llvm::isa<ShapeType, SizeType>(l))
      std::swap(l, r);

    // Handle sizes, propagate error type if present.
    if (llvm::isa<SizeType>(l)) {
      if (llvm::isa<SizeType, IndexType>(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<IndexType>(l)) {
      if (llvm::isa<IndexType>(r))
        acc = r;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<ShapeType>(l)) {
      // Handle shapes, propagate error type if present.
      if (isShapeType(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (isExtentTensorType(l)) {
      auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
      auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
      if (ShapedType::isDynamic(rank1))
        acc = l;
      else if (ShapedType::isDynamic(rank2))
        acc = r;
      else if (rank1 != rank2)
        return emitOptionalError(location, "unequal shape cardinality");
      else
        acc = l;
    }
  }
  inferredReturnTypes.assign({acc});
  return success();
}

// Single-result type lists are compatible if identical, if the size/shape
// side subsumes the other (size vs index, shape vs tensor), or if the two
// shaped types have compatible shapes.
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (!llvm::isa<ShapeType, SizeType>(lhs))
    std::swap(lhs, rhs);

  if (llvm::isa<SizeType>(lhs))
    return llvm::isa<SizeType, IndexType>(rhs);
  if (llvm::isa<ShapeType>(lhs))
    return llvm::isa<ShapeType, TensorType>(rhs);

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}
1476
1477//===----------------------------------------------------------------------===//
1478// RankOp
1479//===----------------------------------------------------------------------===//
1480
1481OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1482 auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1483 if (!shape)
1484 return {};
1485 int64_t rank = shape.getNumElements();
1486 Builder builder(getContext());
1487 return builder.getIndexAttr(rank);
1488}
1489
1490/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1491/// Constant folding fails in cases where only the rank is constant, not the
1492/// shape itself.
1493/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1494///
1495/// Example:
1496///
1497/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1498/// %rank = shape.rank %shape
1499///
1500/// becomes
1501///
1502/// %rank = shape.const_size 3
1503
namespace {
// Rewrites rank(shape_of(%t)) to a constant when %t's tensor type is ranked,
// emitting arith.constant for an index result or shape.const_size for a
// !shape.size result.
struct RankShapeOfCanonicalizationPattern
    : public OpRewritePattern<shape::RankOp> {
  using OpRewritePattern<shape::RankOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::RankOp op,
                                PatternRewriter &rewriter) const override {
    auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();
    auto rankedTensorType =
        llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
    if (!rankedTensorType)
      return failure();
    int64_t rank = rankedTensorType.getRank();
    if (llvm::isa<IndexType>(op.getType())) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
                                                          rank);
    } else if (llvm::isa<shape::SizeType>(op.getType())) {
      rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
    } else {
      return failure();
    }
    return success();
  }
};
} // namespace

void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<RankShapeOfCanonicalizationPattern>(context);
}

// Result is !shape.size for !shape.shape operands, index for extent tensors.
LogicalResult mlir::shape::RankOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1546
1547bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1548 // SizeType is compatible with IndexType.
1550}
1551
1552LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1553
1554//===----------------------------------------------------------------------===//
1555// NumElementsOp
1556//===----------------------------------------------------------------------===//
1557
// Folds shape.num_elements of a constant shape to the product of its extents.
OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {

  // Fold only when argument constant.
  Attribute shape = adaptor.getShape();
  if (!shape)
    return {};

  // 64-bit accumulator; an empty (scalar) shape yields 1.
  APInt product(64, 1);
  for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
    product *= value;
  Builder builder(getContext());
  return builder.getIndexAttr(product.getLimitedValue());
}

// Result is !shape.size for !shape.shape operands, index for extent tensors.
LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    NumElementsOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1582
1583bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1584 TypeRange r) {
1585 // SizeType is compatible with IndexType.
1587}
1588
1589LogicalResult shape::NumElementsOp::verify() {
1590 return verifySizeOrIndexOp(*this);
1591}
1592
1593//===----------------------------------------------------------------------===//
1594// MaxOp
1595//===----------------------------------------------------------------------===//
1596
OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

// Operands of identical type keep that type; mixed operands produce
// !shape.size.
LogicalResult mlir::shape::MaxOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
    inferredReturnTypes.assign({adaptor.getLhs().getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

// Compatible only when both sides are shapes or both are sizes.
bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
    return true;
  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
    return true;
  return false;
}
1623
1624//===----------------------------------------------------------------------===//
1625// MinOp
1626//===----------------------------------------------------------------------===//
1627
OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
  // If operands are equal, just propagate one.
  if (getLhs() == getRhs())
    return getLhs();
  return nullptr;
}

// Operands of identical type keep that type; mixed operands produce
// !shape.size.
LogicalResult mlir::shape::MinOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
    inferredReturnTypes.assign({adaptor.getLhs().getType()});
  else
    inferredReturnTypes.assign({SizeType::get(context)});
  return success();
}

// Compatible only when both sides are shapes or both are sizes.
bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
    return true;
  if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
    return true;
  return false;
}
1654
1655//===----------------------------------------------------------------------===//
1656// MulOp
1657//===----------------------------------------------------------------------===//
1658
// Folds shape.mul with two constant operands to their product (index-typed).
OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
  auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
  if (!lhs)
    return nullptr;
  auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
  if (!rhs)
    return nullptr;
  APInt folded = lhs.getValue() * rhs.getValue();
  Type indexTy = IndexType::get(getContext());
  return IntegerAttr::get(indexTy, folded);
}

// Result is !shape.size if either operand is (error-carrying), else index.
LogicalResult mlir::shape::MulOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
      llvm::isa<SizeType>(adaptor.getRhs().getType()))
    inferredReturnTypes.assign({SizeType::get(context)});
  else
    inferredReturnTypes.assign({IndexType::get(context)});
  return success();
}
1681
1682bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1683 // SizeType is compatible with IndexType.
1685}
1686
1687LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1688
1689//===----------------------------------------------------------------------===//
1690// ShapeOfOp
1691//===----------------------------------------------------------------------===//
1692
namespace {
/// Replace shape_of(x) where x has a constant shape with a const_shape op.
struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
    if (!type || !type.hasStaticShape())
      return failure();
    Location loc = op.getLoc();
    Value constShape =
        ConstShapeOp::create(rewriter, loc,
                             rewriter.getIndexTensorAttr(type.getShape()))
            .getResult();
    // Insert a cast if the const_shape's static result type differs from the
    // original (possibly dynamic) result type.
    if (constShape.getType() != op.getResult().getType())
      constShape = tensor::CastOp::create(rewriter, loc,
                                          op.getResult().getType(), constShape);
    rewriter.replaceOp(op, constShape);
    return success();
  }
};

// Canonicalize
//
// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
//
// to
//
// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// %1 = %shape
//
struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
    if (!tensorReshapeOp)
      return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
    if (!isa<TensorType>(op.getType()))
      return rewriter.notifyMatchFailure(op, "result is not a tensor");

    // Operand 'shape' of 'tensor.reshape' may now be used as the result of
    // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
    // formed IR, it may not be identical (dynamically vs statically shaped),
    // in which case it needs to be cast first using 'tensor.cast'.
    // Additionally, it may not have identical element type (i32 vs index)
    // while it has identical shaped type (dynamic vs static), in which case it
    // needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
    // op result must be shape or extent tensor.
    Value shape = tensorReshapeOp.getShape();

    auto opTensorTy = cast<RankedTensorType>(op.getType());
    auto shapeTensorTy = cast<RankedTensorType>(shape.getType());

    if (opTensorTy != shapeTensorTy) {
      if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
        shape =
            tensor::CastOp::create(rewriter, op.getLoc(), opTensorTy, shape);
      else if (!isExtentTensorType(shapeTensorTy))
        shape = arith::IndexCastOp::create(rewriter, op.getLoc(), opTensorTy,
                                           shape);
    }

    rewriter.replaceOp(op, shape);
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    // Only rewrite casts to 1-D (extent-tensor-shaped) ranked tensors.
    auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
    return success();
  }
};
} // namespace
1796
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
               ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
      context);
}

// For !shape.value_shape arguments the result is !shape.shape; otherwise it
// is a 1-D extent tensor whose length is the argument's rank (dynamic if the
// argument is unranked).
LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
    inferredReturnTypes.assign({ShapeType::get(context)});
  else {
    auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
    int64_t rank =
        shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
    Type indexTy = IndexType::get(context);
    Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
    inferredReturnTypes.assign({extentTensorTy});
  }
  return success();
}

// Single-result type lists are compatible if identical, if either side is the
// opaque !shape.shape type, or if both shaped types have compatible shapes.
bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
  if (l.size() != 1 || r.size() != 1)
    return false;
  if (l == r)
    return true;

  Type lhs = l.front();
  Type rhs = r.front();

  if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
      !llvm::isa<ShapeType, ShapedType>(rhs))
    return false;

  if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
    // Shape type is compatible with all other valid return types.
    return true;

  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
    return true;
  return false;
}

LogicalResult shape::ShapeOfOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}
1845
1846//===----------------------------------------------------------------------===//
1847// SizeToIndexOp
1848//===----------------------------------------------------------------------===//
1849
1850OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1851 // Constant values of both types, `shape.size` and `index`, are represented as
1852 // `IntegerAttr`s which makes constant folding simple.
1853 if (Attribute arg = adaptor.getArg())
1854 return arg;
1855 return OpFoldResult();
1856}
1857
1858void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1859 MLIRContext *context) {
1860 patterns.add<IndexToSizeToIndexCanonicalization>(context);
1861}
1862
1863bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1864 if (inputs.size() != 1 || outputs.size() != 1)
1865 return false;
1866 return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1867 llvm::isa<IndexType>(outputs[0]);
1868}
1869
1870//===----------------------------------------------------------------------===//
1871// YieldOp
1872//===----------------------------------------------------------------------===//
1873
1874LogicalResult shape::YieldOp::verify() {
1875 auto *parentOp = (*this)->getParentOp();
1876 auto results = parentOp->getResults();
1877 auto operands = getOperands();
1878
1879 if (parentOp->getNumResults() != getNumOperands())
1880 return emitOpError() << "number of operands does not match number of "
1881 "results of its parent";
1882 for (auto e : llvm::zip(results, operands))
1883 if (std::get<0>(e).getType() != std::get<1>(e).getType())
1884 return emitOpError() << "types mismatch between yield op and its parent";
1885
1886 return success();
1887}
1888
1889//===----------------------------------------------------------------------===//
1890// SplitAtOp
1891//===----------------------------------------------------------------------===//
1892
1893LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1895 if (!adaptor.getOperand() || !adaptor.getIndex())
1896 return failure();
1897 auto shapeVec = llvm::to_vector<6>(
1898 llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1899 auto shape = llvm::ArrayRef(shapeVec);
1900 auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1901 // Verify that the split point is in the correct range.
1902 // TODO: Constant fold to an "error".
1903 int64_t rank = shape.size();
1904 if (-rank > splitPoint || splitPoint > rank)
1905 return failure();
1906 if (splitPoint < 0)
1907 splitPoint += shape.size();
1908 Builder builder(adaptor.getOperand().getContext());
1909 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1910 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1911 return success();
1912}
1913
1914//===----------------------------------------------------------------------===//
1915// ToExtentTensorOp
1916//===----------------------------------------------------------------------===//
1917
1918OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1919 if (!adaptor.getInput())
1920 return OpFoldResult();
1921 Builder builder(getContext());
1922 auto shape = llvm::to_vector<6>(
1923 llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1924 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1925 builder.getIndexType());
1926 return DenseIntElementsAttr::get(type, shape);
1927}
1928
1929bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1930 if (inputs.size() != 1 || outputs.size() != 1)
1931 return false;
1932 if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1933 if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1934 inputTensor.getRank() != 1)
1935 return false;
1936 } else if (!llvm::isa<ShapeType>(inputs[0])) {
1937 return false;
1938 }
1939
1940 TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1941 return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1942}
1943
1944//===----------------------------------------------------------------------===//
1945// ReduceOp
1946//===----------------------------------------------------------------------===//
1947
1948void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1949 ValueRange initVals) {
1950 OpBuilder::InsertionGuard g(builder);
1951 result.addOperands(shape);
1952 result.addOperands(initVals);
1953
1954 Region *bodyRegion = result.addRegion();
1955 Block *bodyBlock = builder.createBlock(
1956 bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);
1957
1958 Type elementType;
1959 if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
1960 elementType = tensorType.getElementType();
1961 else
1962 elementType = SizeType::get(builder.getContext());
1963 bodyBlock->addArgument(elementType, shape.getLoc());
1964
1965 for (Value initVal : initVals) {
1966 bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
1967 result.addTypes(initVal.getType());
1968 }
1969}
1970
1971LogicalResult ReduceOp::verify() {
1972 // Verify block arg types.
1973 Block &block = getRegion().front();
1974
1975 // The block takes index, extent, and aggregated values as arguments.
1976 auto blockArgsCount = getInitVals().size() + 2;
1977 if (block.getNumArguments() != blockArgsCount)
1978 return emitOpError() << "ReduceOp body is expected to have "
1979 << blockArgsCount << " arguments";
1980
1981 // The first block argument is the index and must always be of type `index`.
1982 if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
1983 return emitOpError(
1984 "argument 0 of ReduceOp body is expected to be of IndexType");
1985
1986 // The second block argument is the extent and must be of type `size` or
1987 // `index`, depending on whether the reduce operation is applied to a shape or
1988 // to an extent tensor.
1989 Type extentTy = block.getArgument(1).getType();
1990 if (llvm::isa<ShapeType>(getShape().getType())) {
1991 if (!llvm::isa<SizeType>(extentTy))
1992 return emitOpError("argument 1 of ReduceOp body is expected to be of "
1993 "SizeType if the ReduceOp operates on a ShapeType");
1994 } else {
1995 if (!llvm::isa<IndexType>(extentTy))
1996 return emitOpError(
1997 "argument 1 of ReduceOp body is expected to be of IndexType if the "
1998 "ReduceOp operates on an extent tensor");
1999 }
2000
2001 for (const auto &type : llvm::enumerate(getInitVals()))
2002 if (block.getArgument(type.index() + 2).getType() != type.value().getType())
2003 return emitOpError() << "type mismatch between argument "
2004 << type.index() + 2
2005 << " of ReduceOp body and initial value "
2006 << type.index();
2007 return success();
2008}
2009
2010ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2011 // Parse operands.
2013 Type shapeOrExtentTensorType;
2014 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
2016 parser.parseColonType(shapeOrExtentTensorType) ||
2017 parser.parseOptionalArrowTypeList(result.types))
2018 return failure();
2019
2020 // Resolve operands.
2021 auto initVals = llvm::ArrayRef(operands).drop_front();
2022 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
2023 result.operands) ||
2024 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2025 result.operands))
2026 return failure();
2027
2028 // Parse the body.
2029 Region *body = result.addRegion();
2030 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
2031 return failure();
2032
2033 // Parse attributes.
2034 if (parser.parseOptionalAttrDict(result.attributes))
2035 return failure();
2036
2037 return success();
2038}
2039
2040void ReduceOp::print(OpAsmPrinter &p) {
2041 p << '(' << getShape() << ", " << getInitVals()
2042 << ") : " << getShape().getType();
2043 p.printOptionalArrowTypeList(getResultTypes());
2044 p << ' ';
2045 p.printRegion(getRegion());
2046 p.printOptionalAttrDict((*this)->getAttrs());
2047}
2048
2049#define GET_OP_CLASSES
2050#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2051
2052#define GET_TYPEDEF_CLASSES
2053#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool isErrorPropagationPossible(TypeRange operandTypes)
Definition Shape.cpp:66
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
Definition Shape.cpp:974
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
Definition Shape.cpp:83
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
Definition Shape.cpp:96
static LogicalResult verifySizeOrIndexOp(Operation *op)
Definition Shape.cpp:71
static int64_t product(ArrayRef< int64_t > vals)
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
Operation & front()
Definition Block.h:153
Operation & back()
Definition Block.h:152
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:153
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:193
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:94
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This is the interface that must be implemented by the dialects of operations to be inlined.
DialectInlinerInterface(Dialect *dialect)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:442
This class represents a single result from folding an operation.
A trait used to provide symbol table functionalities to a region operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
iterator_range< dialect_attr_iterator > dialect_attr_range
Definition Operation.h:634
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
ValueTypeRange< ValueRange > type_range
Definition ValueRange.h:418
Type front()
Return first type in the range.
Definition TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition Traits.cpp:24
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition Traits.cpp:59
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
bool isExtentTensorType(Type)
Definition Shape.cpp:44
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
Definition Shape.cpp:49
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition Shape.cpp:40
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:131
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.