MLIR 23.0.0git
Shape.cpp
Go to the documentation of this file.
1//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "mlir/Dialect/Traits.h"
17#include "mlir/IR/Builders.h"
20#include "mlir/IR/Matchers.h"
25#include "llvm/ADT/SetOperations.h"
26#include "llvm/ADT/SmallVectorExtras.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/raw_ostream.h"
29#include <utility>
30
31using namespace mlir;
32using namespace mlir::shape;
33
34#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
35
36namespace {
37#include "ShapeCanonicalization.inc"
38} // namespace
39
40RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
41 return RankedTensorType::get({rank}, IndexType::get(ctx));
42}
43
45 auto ranked = llvm::dyn_cast<RankedTensorType>(type);
46 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
47}
48
49LogicalResult shape::getShapeVec(Value input,
50 SmallVectorImpl<int64_t> &shapeValues) {
51 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
52 auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
53 if (!type.hasRank())
54 return failure();
55 llvm::append_range(shapeValues, type.getShape());
56 return success();
57 }
59 if (matchPattern(input, m_Constant(&attr))) {
60 llvm::append_range(shapeValues, attr.getValues<int64_t>());
61 return success();
62 }
63 return failure();
64}
65
66static bool isErrorPropagationPossible(TypeRange operandTypes) {
67 return llvm::any_of(operandTypes,
68 llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
69}
70
71static LogicalResult verifySizeOrIndexOp(Operation *op) {
72 assert(op != nullptr && op->getNumResults() == 1);
73 Type resultTy = op->getResultTypes().front();
75 if (!llvm::isa<SizeType>(resultTy))
76 return op->emitOpError()
77 << "if at least one of the operands can hold error values then "
78 "the result must be of type `size` to propagate them";
79 }
80 return success();
81}
82
83static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
84 assert(op != nullptr && op->getNumResults() == 1);
85 Type resultTy = op->getResultTypes().front();
87 if (!llvm::isa<ShapeType>(resultTy))
88 return op->emitOpError()
89 << "if at least one of the operands can hold error values then "
90 "the result must be of type `shape` to propagate them";
91 }
92 return success();
93}
94
95template <typename... Ty>
96static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
97 return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
98}
99
100template <typename... Ty, typename... ranges>
101static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
102 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
103}
104
105//===----------------------------------------------------------------------===//
106// InlinerInterface
107//===----------------------------------------------------------------------===//
108
109namespace {
110/// This class defines the interface for inlining shape dialect ops.
111struct ShapeInlinerInterface : public DialectInlinerInterface {
112 using DialectInlinerInterface::DialectInlinerInterface;
113
114 // Returns true if the given region 'src' can be inlined into the region
115 // 'dest' that is attached to an operation registered to the current dialect.
116 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
117 IRMapping &) const final {
118 return true;
119 }
120
121 // Returns true if the given operation 'op', that is registered to this
122 // dialect, can be inlined into the region 'dest' that is attached to an
123 // operation registered to the current dialect.
124 bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
125 IRMapping &) const final {
126 return true;
127 }
128};
129} // namespace
130
/// Registers the dialect's ops, types, and interfaces.
void ShapeDialect::initialize() {
  // Ops generated from ShapeOps.td.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  // Types generated from ShapeOpsTypes.td.
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
  // Bufferization support for these ops is registered lazily as a promised
  // interface; the implementation lives in the bufferization transforms.
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
                            AssumingYieldOp>();
}
148
/// Materializes a single constant operation for `value` of type `type`, used
/// when folding produces attribute results. Returns null if the attribute
/// cannot be materialized by this dialect or arith.
Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
                                             Attribute value, Type type,
                                             Location loc) {
  // Poison folds materialize through the ub dialect.
  if (auto poison = dyn_cast<ub::PoisonAttr>(value))
    return ub::PoisonOp::create(builder, loc, type, poison);

  // Shapes (either `!shape.shape` or extent tensors) become const_shape ops.
  if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
    return ConstShapeOp::create(builder, loc, type,
                                llvm::cast<DenseIntElementsAttr>(value));
  if (llvm::isa<SizeType>(type))
    return ConstSizeOp::create(builder, loc, type,
                               llvm::cast<IntegerAttr>(value));
  if (llvm::isa<WitnessType>(type))
    return ConstWitnessOp::create(builder, loc, type,
                                  llvm::cast<BoolAttr>(value));

  // Anything else (e.g. plain index constants) is delegated to arith.
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
167
168LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
169 NamedAttribute attribute) {
170 // Verify shape.lib attribute.
171 if (attribute.getName() == "shape.lib") {
172 if (!op->hasTrait<OpTrait::SymbolTable>())
173 return op->emitError(
174 "shape.lib attribute may only be on op implementing SymbolTable");
175
176 if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
177 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
178 if (!symbol)
179 return op->emitError("shape function library ")
180 << symbolRef << " not found";
181 return isa<shape::FunctionLibraryOp>(symbol)
182 ? success()
183 : op->emitError()
184 << symbolRef << " required to be shape function library";
185 }
186
187 if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
188 // Verify all entries are function libraries and mappings in libraries
189 // refer to unique ops.
191 for (auto it : arr) {
192 if (!llvm::isa<SymbolRefAttr>(it))
193 return op->emitError(
194 "only SymbolRefAttr allowed in shape.lib attribute array");
195
196 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
197 SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
198 if (!shapeFnLib)
199 return op->emitError()
200 << it << " does not refer to FunctionLibraryOp";
201 for (auto mapping : shapeFnLib.getMapping()) {
202 if (!key.insert(mapping.getName()).second) {
203 return op->emitError("only one op to shape mapping allowed, found "
204 "multiple for `")
205 << mapping.getName() << "`";
206 }
207 }
208 }
209 return success();
210 }
211
212 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
213 "allowed as shape.lib attribute");
214 }
215 return success();
216}
217
218//===----------------------------------------------------------------------===//
219// AnyOp
220//===----------------------------------------------------------------------===//
221
222// TODO: Canonicalization should be implemented for shapes that can be
223// determined through mixtures of the known dimensions of the inputs.
224OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
225 // Only the last operand is checked because AnyOp is commutative.
226 if (adaptor.getInputs().back())
227 return adaptor.getInputs().back();
228
229 return nullptr;
230}
231
232//===----------------------------------------------------------------------===//
233// AssumingOp
234//===----------------------------------------------------------------------===//
235
236ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
237 result.regions.reserve(1);
238 Region *doRegion = result.addRegion();
239
240 auto &builder = parser.getBuilder();
242 if (parser.parseOperand(cond) ||
243 parser.resolveOperand(cond, builder.getType<WitnessType>(),
244 result.operands))
245 return failure();
246
247 // Parse optional results type list.
248 if (parser.parseOptionalArrowTypeList(result.types))
249 return failure();
250
251 // Parse the region and add a terminator if elided.
252 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
253 return failure();
254 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
255
256 // Parse the optional attribute list.
257 if (parser.parseOptionalAttrDict(result.attributes))
258 return failure();
259 return success();
260}
261
262void AssumingOp::print(OpAsmPrinter &p) {
263 bool yieldsResults = !getResults().empty();
264
265 p << " " << getWitness();
266 if (yieldsResults)
267 p << " -> (" << getResultTypes() << ")";
268 p << ' ';
269 p.printRegion(getDoRegion(),
270 /*printEntryBlockArgs=*/false,
271 /*printBlockTerminators=*/yieldsResults);
272 p.printOptionalAttrDict((*this)->getAttrs());
273}
274
275namespace {
276// Removes AssumingOp with a passing witness and inlines the region.
277struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
278 using OpRewritePattern<AssumingOp>::OpRewritePattern;
279
280 LogicalResult matchAndRewrite(AssumingOp op,
281 PatternRewriter &rewriter) const override {
282 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
283 if (!witness || !witness.getPassingAttr())
284 return failure();
285
286 AssumingOp::inlineRegionIntoParent(op, rewriter);
287 return success();
288 }
289};
290
/// Rewrites an AssumingOp that yields unused results into a new AssumingOp
/// yielding only the used ones; dead results are dropped from the yield.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    for (auto [opResult, yieldOperand] :
         llvm::zip(op.getResults(), yieldOp.getOperands())) {
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    // The new op's result types are exactly the kept yield operand types.
    auto newOp = AssumingOp::create(
        rewriter, op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones. Unused old
    // results map to null (no replacement needed); used ones consume the new
    // results in order.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
335} // namespace
336
337void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
338 MLIRContext *context) {
339 patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
340}
341
342// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
343void AssumingOp::getSuccessorRegions(
345 // AssumingOp has unconditional control flow into the region and back to the
346 // parent, so return the correct RegionSuccessor purely based on the index
347 // being None or 0.
348 if (!point.isParent()) {
349 regions.push_back(RegionSuccessor::parent());
350 return;
351 }
352
353 regions.push_back(RegionSuccessor(&getDoRegion()));
354}
355
356ValueRange AssumingOp::getSuccessorInputs(RegionSuccessor successor) {
357 return successor.isParent() ? ValueRange(getResults()) : ValueRange();
358}
359
/// Inlines the body of `op` into its parent block at the rewriter's current
/// insertion point, replacing the op's results with the yielded values.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  // Split the parent block so the assuming body can be spliced in between.
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
379
380void AssumingOp::build(
381 OpBuilder &builder, OperationState &result, Value witness,
383 OpBuilder::InsertionGuard g(builder);
384
385 result.addOperands(witness);
386 Region *bodyRegion = result.addRegion();
387 builder.createBlock(bodyRegion);
388
389 // Build body.
390 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
391 AssumingYieldOp::create(builder, result.location, yieldValues);
392
393 SmallVector<Type, 2> assumingTypes;
394 for (Value v : yieldValues)
395 assumingTypes.push_back(v.getType());
396 result.addTypes(assumingTypes);
397}
398
399//===----------------------------------------------------------------------===//
400// AddOp
401//===----------------------------------------------------------------------===//
402
403LogicalResult mlir::shape::AddOp::inferReturnTypes(
404 MLIRContext *context, std::optional<Location> location,
405 AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
406 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
407 llvm::isa<SizeType>(adaptor.getRhs().getType()))
408 inferredReturnTypes.assign({SizeType::get(context)});
409 else
410 inferredReturnTypes.assign({IndexType::get(context)});
411 return success();
412}
413
414bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
415 // SizeType is compatible with IndexType.
417}
418
419OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
420 // add(x, 0) -> x
421 if (matchPattern(getRhs(), m_Zero()))
422 return getLhs();
423
425 adaptor.getOperands(),
426 [](APInt a, const APInt &b) { return std::move(a) + b; });
427}
428
429LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
430
431//===----------------------------------------------------------------------===//
432// AssumingAllOp
433//===----------------------------------------------------------------------===//
434
435namespace {
436
437// Merge multiple `shape.assuming_all` operations together.
438//
439// %0 = shape.assuming_all %w0, %w1
440// %1 = shape.assuming_all %w2, %0
441//
442// to:
443//
444// %0 = shape.assuming_all %w0, %w2, %w2
445struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
446 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
447
448 LogicalResult matchAndRewrite(AssumingAllOp op,
449 PatternRewriter &rewriter) const override {
450 SmallVector<Value> operands;
451
452 for (Value operand : op.getInputs()) {
453 if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
454 operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
455 else
456 operands.push_back(operand);
457 }
458
459 // We didn't find any other `assuming_all` ops to merge with.
460 if (operands.size() == op.getNumOperands())
461 return failure();
462
463 // Replace with a new `assuming_all` operation with merged constraints.
464 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
465 return success();
466 }
467};
468
469// Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
470// are subsumed by others.
471//
472// %0 = shape.cstr_broadcastable %shape0, %shape1
473// %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
474//
475// %2 = shape.cstr_broadcastable %shape3, %shape4
476// %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
477//
478// %4 = shape.assuming_all %0, %1, %2, %3
479//
480// to:
481//
482// %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
483// %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
484// %2 = shape.assuming_all %0, %1
485//
486// In this example if shapes [0, 1, 2] are broadcastable, then it means that
487// shapes [0, 1] are broadcastable too, and can be removed from the list of
488// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
489// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
490struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
491 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
492
493 LogicalResult matchAndRewrite(AssumingAllOp op,
494 PatternRewriter &rewriter) const override {
495 // Collect all `CstrBroadcastableOp` operands first.
497 for (Value operand : op.getInputs()) {
498 // TODO: Apply this optimization if some of the witnesses are not
499 // produced by the `cstr_broadcastable`.
500 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
501 if (!broadcastable)
502 return failure();
503
504 operands.insert(broadcastable);
505 }
506
507 // Skip trivial `assuming_all` operations.
508 if (operands.size() <= 1)
509 return failure();
510
511 // Collect shapes checked by `cstr_broadcastable` operands.
512 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
513 for (auto cstr : operands) {
514 DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
515 shapes.emplace_back(cstr, std::move(shapesSet));
516 }
517
518 // Sort by the number of shape operands (larger to smaller).
519 llvm::sort(shapes, [](auto a, auto b) {
520 return a.first.getNumOperands() > b.first.getNumOperands();
521 });
522
523 // We start from the `cst_broadcastable` operations with largest number of
524 // shape operands, and remove redundant `cst_broadcastable` operations. We
525 // do this until we find a set of `cst_broadcastable` operations with
526 // non-overlapping constraints.
527 SmallVector<CstrBroadcastableOp> markedForErase;
528
529 for (unsigned i = 0; i < shapes.size(); ++i) {
530 auto isSubset = [&](auto pair) {
531 return llvm::set_is_subset(pair.second, shapes[i].second);
532 };
533
534 // Keep redundant `cstr_broadcastable` operations to be erased.
535 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
536 for (auto *it0 = it; it0 < shapes.end(); ++it0)
537 markedForErase.push_back(it0->first);
538 shapes.erase(it, shapes.end());
539 }
540
541 // We didn't find any operands that could be removed.
542 if (markedForErase.empty())
543 return failure();
544
545 // Collect non-overlapping `cst_broadcastable` constraints.
546 SmallVector<Value> uniqueConstraints;
547 for (auto &shape : shapes)
548 uniqueConstraints.push_back(shape.first.getResult());
549
550 // Replace with a new `assuming_all` operation ...
551 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
552
553 // ... and maybe erase `cstr_broadcastable` ops without uses.
554 for (auto &op : markedForErase)
555 if (op->use_empty())
556 rewriter.eraseOp(op);
557
558 return success();
559 }
560};
561
/// Folds `assuming_all` whose witnesses are all produced by `cstr_eq` ops with
/// overlapping shape sets into a single `cstr_eq` over the union of shapes.
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      // Every witness must come from a `cstr_eq`.
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      // The merge is only sound if each `cstr_eq` shares at least one shape
      // with those collected so far (equality is transitive through the
      // shared element); bail out on a fully disjoint set.
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};
584
585template <typename OpTy>
586struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
587 using OpRewritePattern<OpTy>::OpRewritePattern;
588
589 LogicalResult matchAndRewrite(OpTy op,
590 PatternRewriter &rewriter) const override {
591 // Find unique operands.
592 SetVector<Value> unique(op.operand_begin(), op.operand_end());
593
594 // Reduce op to equivalent with unique operands.
595 if (unique.size() < op.getNumOperands()) {
596 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
597 unique.takeVector(), op->getAttrs());
598 return success();
599 }
600
601 return failure();
602 }
603};
604} // namespace
605
606void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
607 MLIRContext *context) {
609 .add<MergeAssumingAllOps, AssumingAllOneOp,
610 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
611 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
612}
613
/// Folds `assuming_all` over constant witnesses. Note this fold mutates the
/// op in place: statically known witnesses are erased from the operand list
/// even when the op itself cannot be fully folded away.
OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
    Attribute a = adaptor.getInputs()[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!llvm::cast<BoolAttr>(a).getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
634
635LogicalResult AssumingAllOp::verify() {
636 // Ensure that AssumingAllOp contains at least one operand
637 if (getNumOperands() == 0)
638 return emitOpError("no operands specified");
639
640 return success();
641}
642
643//===----------------------------------------------------------------------===//
644// BroadcastOp
645//===----------------------------------------------------------------------===//
646
647OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
648 if (getShapes().size() == 1) {
649 // Otherwise, we need a cast which would be a canonicalization, not folding.
650 if (getShapes().front().getType() != getType())
651 return nullptr;
652 return getShapes().front();
653 }
654
655 if (!adaptor.getShapes().front())
656 return nullptr;
657
658 SmallVector<int64_t, 6> resultShape(
659 llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
660 .getValues<int64_t>());
661
662 for (auto next : adaptor.getShapes().drop_front()) {
663 if (!next)
664 return nullptr;
665 auto nextShape = llvm::to_vector<6>(
666 llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
667
669 // If the shapes are not compatible, we can't fold it.
670 // TODO: Fold to an "error".
671 if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
672 return nullptr;
673
674 resultShape.clear();
675 std::copy(tmpShape.begin(), tmpShape.end(),
676 std::back_inserter(resultShape));
677 }
678
679 Builder builder(getContext());
680 return builder.getIndexTensorAttr(resultShape);
681}
682
683LogicalResult BroadcastOp::verify() {
684 return verifyShapeOrExtentTensorOp(*this);
685}
686
687namespace {
688template <typename OpTy>
689struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
690 using OpRewritePattern<OpTy>::OpRewritePattern;
691
692 LogicalResult matchAndRewrite(OpTy op,
693 PatternRewriter &rewriter) const override {
694 auto isPotentiallyNonEmptyShape = [](Value shape) {
695 if (auto extentTensorTy =
696 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
697 if (extentTensorTy.getDimSize(0) == 0)
698 return false;
699 }
700 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
701 if (constShape.getShape().empty())
702 return false;
703 }
704 return true;
705 };
706 auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
707 isPotentiallyNonEmptyShape);
708
709 // Replace the op with empty shape constant if all operants are reduced to
710 // be empty.
711 if (newOperands.empty()) {
712 rewriter.replaceOpWithNewOp<ConstShapeOp>(
713 op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
714 return success();
715 }
716
717 // Reduce op to equivalent without empty shape operands.
718 if (newOperands.size() < op.getNumOperands()) {
719 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
720 op->getAttrs());
721 return success();
722 }
723
724 return failure();
725 }
726};
727
728struct BroadcastForwardSingleOperandPattern
729 : public OpRewritePattern<BroadcastOp> {
730 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
731
732 LogicalResult matchAndRewrite(BroadcastOp op,
733 PatternRewriter &rewriter) const override {
734 if (op.getNumOperands() != 1)
735 return failure();
736 Value replacement = op.getShapes().front();
737
738 // Insert cast if needed.
739 if (replacement.getType() != op.getType()) {
740 auto loc = op.getLoc();
741 if (llvm::isa<ShapeType>(op.getType())) {
742 replacement = FromExtentTensorOp::create(rewriter, loc, replacement);
743 } else {
744 assert(!llvm::isa<ShapeType>(op.getType()) &&
745 !llvm::isa<ShapeType>(replacement.getType()) &&
746 "expect extent tensor cast");
748 tensor::CastOp::create(rewriter, loc, op.getType(), replacement);
749 }
750 }
751
752 rewriter.replaceOp(op, replacement);
753 return success();
754 }
755};
756
757struct BroadcastFoldConstantOperandsPattern
758 : public OpRewritePattern<BroadcastOp> {
759 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
760
761 LogicalResult matchAndRewrite(BroadcastOp op,
762 PatternRewriter &rewriter) const override {
763 SmallVector<int64_t, 8> foldedConstantShape;
764 SmallVector<Value, 8> newShapeOperands;
765 for (Value shape : op.getShapes()) {
766 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
767 SmallVector<int64_t, 8> newFoldedConstantShape;
769 foldedConstantShape,
770 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
771 newFoldedConstantShape)) {
772 foldedConstantShape = newFoldedConstantShape;
773 continue;
774 }
775 }
776 newShapeOperands.push_back(shape);
777 }
778
779 // Need at least two constant operands to fold anything.
780 if (op.getNumOperands() - newShapeOperands.size() < 2)
781 return failure();
782
783 auto foldedConstantOperandsTy = RankedTensorType::get(
784 {static_cast<int64_t>(foldedConstantShape.size())},
785 rewriter.getIndexType());
786 newShapeOperands.push_back(
787 ConstShapeOp::create(rewriter, op.getLoc(), foldedConstantOperandsTy,
788 rewriter.getIndexTensorAttr(foldedConstantShape)));
789 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
790 newShapeOperands);
791 return success();
792 }
793};
794
/// Strips `tensor.cast` operands when the cast only loses static shape
/// information (result dim is dynamic), forwarding the cast's source instead.
template <typename OpTy>
struct CanonicalizeCastExtentTensorOperandsPattern
    : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy op,
                                PatternRewriter &rewriter) const override {
    // Canonicalize operands.
    bool anyChange = false;
    auto canonicalizeOperand = [&](Value operand) -> Value {
      if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
        // Only eliminate the cast if it holds no shape information, i.e. the
        // cast result's single dimension is dynamic.
        bool isInformationLoosingCast =
            llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
        if (isInformationLoosingCast) {
          anyChange = true;
          return castOp.getSource();
        }
      }
      return operand;
    };
    auto newOperands =
        llvm::map_to_vector<8>(op.getOperands(), canonicalizeOperand);

    // Rewrite op if any change required.
    if (!anyChange)
      return failure();
    rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
    return success();
  }
};
826
827struct BroadcastConcretizeResultTypePattern
828 : public OpRewritePattern<BroadcastOp> {
829 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
830
831 LogicalResult matchAndRewrite(BroadcastOp op,
832 PatternRewriter &rewriter) const override {
833 // Only concretize dynamic extent tensor result types.
834 auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
835 if (!resultTy || !resultTy.isDynamicDim(0))
836 return failure();
837
838 // Infer resulting shape rank if possible.
839 int64_t maxRank = 0;
840 for (Value shape : op.getShapes()) {
841 if (auto extentTensorTy =
842 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
843 // Cannot infer resulting shape rank if any operand is dynamically
844 // ranked.
845 if (extentTensorTy.isDynamicDim(0))
846 return failure();
847 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
848 }
849 }
850
851 auto newOp = BroadcastOp::create(rewriter, op.getLoc(),
853 op.getShapes());
854 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
855 return success();
856 }
857};
858} // namespace
859
860void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
861 MLIRContext *context) {
862 patterns.add<BroadcastConcretizeResultTypePattern,
863 BroadcastFoldConstantOperandsPattern,
864 BroadcastForwardSingleOperandPattern,
865 CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
866 RemoveDuplicateOperandsPattern<BroadcastOp>,
867 RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
868}
869
870//===----------------------------------------------------------------------===//
871// ConcatOp
872//===----------------------------------------------------------------------===//
873
874OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
875 if (!adaptor.getLhs() || !adaptor.getRhs())
876 return nullptr;
877 auto lhsShape = llvm::to_vector<6>(
878 llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
879 auto rhsShape = llvm::to_vector<6>(
880 llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
881 SmallVector<int64_t, 6> resultShape;
882 resultShape.append(lhsShape.begin(), lhsShape.end());
883 resultShape.append(rhsShape.begin(), rhsShape.end());
884 Builder builder(getContext());
885 return builder.getIndexTensorAttr(resultShape);
886}
887
888//===----------------------------------------------------------------------===//
889// ConstShapeOp
890//===----------------------------------------------------------------------===//
891
/// Prints as `shape.const_shape {attrs} [e0, e1, ...] : type`, with the
/// `shape` attribute elided from the dict since it is printed inline.
void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}
900
901ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
902 if (parser.parseOptionalAttrDict(result.attributes))
903 return failure();
904 // We piggy-back on ArrayAttr parsing, though we don't internally store the
905 // shape as an ArrayAttr.
906 // TODO: Implement custom parser and maybe make syntax a bit more concise.
907 Attribute extentsRaw;
908 NamedAttrList dummy;
909 if (parser.parseAttribute(extentsRaw, "dummy", dummy))
910 return failure();
911 auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
912 if (!extentsArray)
913 return failure();
915 for (Attribute extent : extentsArray) {
916 IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
917 if (!attr)
918 return failure();
919 ints.push_back(attr.getInt());
920 }
921 Builder &builder = parser.getBuilder();
922 result.addAttribute("shape", builder.getIndexTensorAttr(ints));
923 Type resultTy;
924 if (parser.parseColonType(resultTy))
925 return failure();
926 result.types.push_back(resultTy);
927 return success();
928}
929
930OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
931
932void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
933 MLIRContext *context) {
934 patterns.add<TensorCastConstShape>(context);
935}
936
937LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
938 MLIRContext *context, std::optional<Location> location,
939 ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
940 Builder b(context);
941 const Properties prop = adaptor.getProperties();
942 inferredReturnTypes.assign({RankedTensorType::get(
943 {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
944 return success();
945}
946
947bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
948 TypeRange r) {
949 if (l.size() != 1 || r.size() != 1)
950 return false;
951
952 Type lhs = l.front();
953 Type rhs = r.front();
954
955 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
956 // Shape type is compatible with all other valid return types.
957 return true;
958 return lhs == rhs;
959}
960
961//===----------------------------------------------------------------------===//
962// CstrBroadcastableOp
963//===----------------------------------------------------------------------===//
964
965void CstrBroadcastableOp::getCanonicalizationPatterns(
967 // Canonicalization patterns have overlap with the considerations during
968 // folding in case additional shape information is inferred at some point that
969 // does not result in folding.
970 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
971 CstrBroadcastableEqOps,
972 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
973 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
974}
975
976// Return true if there is exactly one attribute not representing a scalar
977// broadcast.
979 bool nonScalarSeen = false;
980 for (Attribute a : attributes) {
981 if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
982 if (nonScalarSeen)
983 return false;
984 nonScalarSeen = true;
985 }
986 }
987 return true;
988}
989
990OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
991 // No broadcasting is needed if all operands but one are scalar.
992 if (hasAtMostSingleNonScalar(adaptor.getShapes()))
993 return BoolAttr::get(getContext(), true);
994
995 if ([&] {
997 for (const auto &operand : adaptor.getShapes()) {
998 if (!operand)
999 return false;
1000 extents.push_back(llvm::to_vector<6>(
1001 llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
1002 }
1004 }())
1005 return BoolAttr::get(getContext(), true);
1006
1007 // Lastly, see if folding can be completed based on what constraints are known
1008 // on the input shapes.
1009 if ([&] {
1011 for (auto shapeValue : getShapes()) {
1012 extents.emplace_back();
1013 if (failed(getShapeVec(shapeValue, extents.back())))
1014 return false;
1015 }
1017 }())
1018 return BoolAttr::get(getContext(), true);
1019
1020 // Because a failing witness result here represents an eventual assertion
1021 // failure, we do not replace it with a constant witness.
1022 return nullptr;
1023}
1024
1025LogicalResult CstrBroadcastableOp::verify() {
1026 // Ensure that CstrBroadcastableOp contains at least two operands
1027 if (getNumOperands() < 2)
1028 return emitOpError("required at least 2 input shapes");
1029 return success();
1030}
1031
1032//===----------------------------------------------------------------------===//
1033// CstrEqOp
1034//===----------------------------------------------------------------------===//
1035
1036void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1037 MLIRContext *context) {
1038 // If inputs are equal, return passing witness
1039 patterns.add<CstrEqEqOps>(context);
1040}
1041
1042OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
1043 if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
1044 return a && a == adaptor.getShapes().front();
1045 }))
1046 return BoolAttr::get(getContext(), true);
1047
1048 // Because a failing witness result here represents an eventual assertion
1049 // failure, we do not try to replace it with a constant witness. Similarly, we
1050 // cannot if there are any non-const inputs.
1051 return nullptr;
1052}
1053
1054//===----------------------------------------------------------------------===//
1055// ConstSizeOp
1056//===----------------------------------------------------------------------===//
1057
1058void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
1059 int64_t value) {
1060 build(builder, result, builder.getIndexAttr(value));
1061}
1062
1063OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
1064
1065void ConstSizeOp::getAsmResultNames(
1066 llvm::function_ref<void(Value, StringRef)> setNameFn) {
1067 SmallString<4> buffer;
1068 llvm::raw_svector_ostream os(buffer);
1069 os << "c" << getValue();
1070 setNameFn(getResult(), os.str());
1071}
1072
1073//===----------------------------------------------------------------------===//
1074// ConstWitnessOp
1075//===----------------------------------------------------------------------===//
1076
1077OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
1078
1079//===----------------------------------------------------------------------===//
1080// CstrRequireOp
1081//===----------------------------------------------------------------------===//
1082
1083OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
1084 return adaptor.getPred();
1085}
1086
1087//===----------------------------------------------------------------------===//
1088// DimOp
1089//===----------------------------------------------------------------------===//
1090
1091std::optional<int64_t> DimOp::getConstantIndex() {
1092 if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1093 return constSizeOp.getValue().getLimitedValue();
1094 if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1095 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1096 return std::nullopt;
1097}
1098
1099OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1100 Type valType = getValue().getType();
1101 auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1102 if (!valShapedType || !valShapedType.hasRank())
1103 return nullptr;
1104 std::optional<int64_t> index = getConstantIndex();
1105 if (!index.has_value())
1106 return nullptr;
1107 if (index.value() < 0 || index.value() >= valShapedType.getRank())
1108 return nullptr;
1109 auto extent = valShapedType.getDimSize(*index);
1110 if (ShapedType::isDynamic(extent))
1111 return nullptr;
1112 return IntegerAttr::get(IndexType::get(getContext()), extent);
1113}
1114
1115LogicalResult mlir::shape::DimOp::inferReturnTypes(
1116 MLIRContext *context, std::optional<Location> location,
1117 DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1118 inferredReturnTypes.assign({adaptor.getIndex().getType()});
1119 return success();
1120}
1121
1122bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1124}
1125
1126//===----------------------------------------------------------------------===//
1127// DivOp
1128//===----------------------------------------------------------------------===//
1129
1130OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
1131 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1132 if (!lhs)
1133 return nullptr;
1134 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1135 if (!rhs || rhs.getValue().isZero())
1136 return nullptr;
1137
1138 // Division in APInt does not follow floor(lhs, rhs) when the result is
1139 // negative. Rather, APInt rounds toward zero.
1140 APInt quotient, remainder;
1141 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1142 if (quotient.isNegative() && !remainder.isZero()) {
1143 quotient -= 1;
1144 }
1145
1146 Type indexTy = IndexType::get(getContext());
1147 return IntegerAttr::get(indexTy, quotient);
1148}
1149
1150LogicalResult mlir::shape::DivOp::inferReturnTypes(
1151 MLIRContext *context, std::optional<Location> location,
1152 DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1153 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1154 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1155 inferredReturnTypes.assign({SizeType::get(context)});
1156 else
1157 inferredReturnTypes.assign({IndexType::get(context)});
1158 return success();
1159}
1160
1161bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1162 // SizeType is compatible with IndexType.
1164}
1165
1166LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1167
1168//===----------------------------------------------------------------------===//
1169// ShapeEqOp
1170//===----------------------------------------------------------------------===//
1171
1172OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
1173 bool allSame = true;
1174 if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1175 return {};
1176 for (Attribute operand : adaptor.getShapes().drop_front()) {
1177 if (!operand)
1178 return {};
1179 allSame = allSame && operand == adaptor.getShapes().front();
1180 }
1181 return BoolAttr::get(getContext(), allSame);
1182}
1183
1184//===----------------------------------------------------------------------===//
1185// IndexToSizeOp
1186//===----------------------------------------------------------------------===//
1187
1188OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1189 // Constant values of both types, `shape.size` and `index`, are represented as
1190 // `IntegerAttr`s which makes constant folding simple.
1191 if (Attribute arg = adaptor.getArg())
1192 return arg;
1193 return {};
1194}
1195
1196void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1197 MLIRContext *context) {
1198 patterns.add<SizeToIndexToSizeCanonicalization>(context);
1199}
1200
1201//===----------------------------------------------------------------------===//
1202// FromExtentsOp
1203//===----------------------------------------------------------------------===//
1204
1205OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1207 for (Attribute attr : adaptor.getExtents()) {
1208 auto intAttr = llvm::dyn_cast_if_present<IntegerAttr>(attr);
1209 if (!intAttr)
1210 return nullptr;
1211 extents.push_back(intAttr.getInt());
1212 }
1213 Builder builder(getContext());
1214 return builder.getIndexTensorAttr(extents);
1215}
1216
1217//===----------------------------------------------------------------------===//
1218// FunctionLibraryOp
1219//===----------------------------------------------------------------------===//
1220
1221void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1222 StringRef name) {
1223 result.attributes.push_back(builder.getNamedAttr(
1225}
1226
1227FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1228 auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1229 getMapping().get(op->getName().getIdentifier()));
1230 if (!attr)
1231 return nullptr;
1232 return lookupSymbol<FuncOp>(attr);
1233}
1234
1235ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1237 // Parse the op name.
1238 StringAttr nameAttr;
1240 result.attributes))
1241 return failure();
1242
1243 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1244 return failure();
1245
1246 auto *bodyRegion = result.addRegion();
1247 if (parser.parseRegion(*bodyRegion))
1248 return failure();
1249
1250 if (parser.parseKeyword("mapping"))
1251 return failure();
1252
1253 DictionaryAttr mappingAttr;
1254 if (parser.parseAttribute(mappingAttr,
1255 parser.getBuilder().getType<NoneType>(), "mapping",
1256 result.attributes))
1257 return failure();
1258 return success();
1259}
1260
1261void FunctionLibraryOp::print(OpAsmPrinter &p) {
1262 p << ' ';
1263 p.printSymbolName(getName());
1265 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1266 p << ' ';
1267 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1268 /*printBlockTerminators=*/false);
1269 p << " mapping ";
1270 p.printAttributeWithoutType(getMappingAttr());
1271}
1272
1273//===----------------------------------------------------------------------===//
1274// FuncOp
1275//===----------------------------------------------------------------------===//
1276
1277FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1279 OpBuilder builder(location->getContext());
1280 OperationState state(location, getOperationName());
1281 FuncOp::build(builder, state, name, type, attrs);
1282 return cast<FuncOp>(Operation::create(state));
1283}
1284FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1286 SmallVector<NamedAttribute, 8> attrRef(attrs);
1287 return create(location, name, type, llvm::ArrayRef(attrRef));
1288}
1289FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1291 ArrayRef<DictionaryAttr> argAttrs) {
1292 FuncOp func = create(location, name, type, attrs);
1293 func.setAllArgAttrs(argAttrs);
1294 return func;
1295}
1296
1297void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
1298 FunctionType type, ArrayRef<NamedAttribute> attrs,
1299 ArrayRef<DictionaryAttr> argAttrs) {
1300 state.addAttribute(FuncOp::getSymNameAttrName(state.name),
1301 builder.getStringAttr(name));
1302 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
1303 TypeAttr::get(type));
1304 state.attributes.append(attrs.begin(), attrs.end());
1305 state.addRegion();
1306
1307 if (argAttrs.empty())
1308 return;
1309 assert(type.getNumInputs() == argAttrs.size());
1311 builder, state, argAttrs, /*resultAttrs=*/{},
1312 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
1313}
1314
1315ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1316 auto buildFuncType =
1317 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
1319 std::string &) { return builder.getFunctionType(argTypes, results); };
1320
1322 parser, result, /*allowVariadic=*/false,
1323 getFunctionTypeAttrName(result.name), buildFuncType,
1324 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1325}
1326
1327void FuncOp::print(OpAsmPrinter &p) {
1329 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
1330 getArgAttrsAttrName(), getResAttrsAttrName());
1331}
1332
1333//===----------------------------------------------------------------------===//
1334// GetExtentOp
1335//===----------------------------------------------------------------------===//
1336
1337std::optional<int64_t> GetExtentOp::getConstantDim() {
1338 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1339 return constSizeOp.getValue().getLimitedValue();
1340 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1341 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1342 return std::nullopt;
1343}
1344
1345OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
1346 auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1347 if (!elements)
1348 return nullptr;
1349 std::optional<int64_t> dim = getConstantDim();
1350 if (!dim.has_value())
1351 return nullptr;
1352 if (dim.value() >= elements.getNumElements())
1353 return nullptr;
1354 return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1355}
1356
1357void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1358 int64_t dim) {
1359 auto loc = result.location;
1360 auto dimAttr = builder.getIndexAttr(dim);
1361 if (llvm::isa<ShapeType>(shape.getType())) {
1362 Value dim = ConstSizeOp::create(builder, loc, dimAttr);
1363 build(builder, result, builder.getType<SizeType>(), shape, dim);
1364 } else {
1365 Value dim = arith::ConstantOp::create(builder, loc, builder.getIndexType(),
1366 dimAttr);
1367 build(builder, result, builder.getIndexType(), shape, dim);
1368 }
1369}
1370
1371LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1372 MLIRContext *context, std::optional<Location> location,
1373 GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1374 inferredReturnTypes.assign({IndexType::get(context)});
1375 return success();
1376}
1377
1378bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1379 TypeRange r) {
1380 // SizeType is compatible with IndexType.
1382}
1383
1384LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1385
1386//===----------------------------------------------------------------------===//
1387// IsBroadcastableOp
1388//===----------------------------------------------------------------------===//
1389
1390void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1391 MLIRContext *context) {
1392 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1393}
1394
1395OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1396 // Can always broadcast fewer than two shapes.
1397 if (adaptor.getShapes().size() < 2) {
1398 return BoolAttr::get(getContext(), true);
1399 }
1400
1401 return nullptr;
1402}
1403
1404//===----------------------------------------------------------------------===//
1405// MeetOp
1406//===----------------------------------------------------------------------===//
1407
// Infers the meet (most refined compatible type) over all operand types by a
// left fold: sizes meet sizes/indices, shapes meet shapes/extent tensors, and
// extent tensors must agree on their static rank when both are static.
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  if (adaptor.getOperands().empty())
    return failure();

  // A "shape" here is either !shape.shape or a rank-1 index tensor.
  auto isShapeType = [](Type arg) {
    if (llvm::isa<ShapeType>(arg))
      return true;
    return isExtentTensorType(arg);
  };

  ValueRange::type_range types = adaptor.getOperands().getTypes();
  Type acc = types.front();
  for (auto t : drop_begin(types)) {
    Type l = acc, r = t;
    // Normalize so that a shape/size type (if any) ends up on the left.
    if (!llvm::isa<ShapeType, SizeType>(l))
      std::swap(l, r);

    // Handle sizes, propagate error type if present.
    if (llvm::isa<SizeType>(l)) {
      if (llvm::isa<SizeType, IndexType>(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<IndexType>(l)) {
      if (llvm::isa<IndexType>(r))
        acc = r;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<ShapeType>(l)) {
      // Handle shapes, propagate error type if present.
      if (isShapeType(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (isExtentTensorType(l)) {
      // Both sides are extent tensors: prefer the statically-ranked one, and
      // reject mismatching static ranks.
      auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
      auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
      if (ShapedType::isDynamic(rank1))
        acc = l;
      else if (ShapedType::isDynamic(rank2))
        acc = r;
      else if (rank1 != rank2)
        return emitOptionalError(location, "unequal shape cardinality");
      else
        acc = l;
    }
  }
  inferredReturnTypes.assign({acc});
  return success();
}
1460
1461bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1462 if (l.size() != 1 || r.size() != 1)
1463 return false;
1464 if (l == r)
1465 return true;
1466
1467 Type lhs = l.front();
1468 Type rhs = r.front();
1469
1470 if (!llvm::isa<ShapeType, SizeType>(lhs))
1471 std::swap(lhs, rhs);
1472
1473 if (llvm::isa<SizeType>(lhs))
1474 return llvm::isa<SizeType, IndexType>(rhs);
1475 if (llvm::isa<ShapeType>(lhs))
1476 return llvm::isa<ShapeType, TensorType>(rhs);
1477
1478 if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1479 return true;
1480 return false;
1481}
1482
1483//===----------------------------------------------------------------------===//
1484// RankOp
1485//===----------------------------------------------------------------------===//
1486
1487OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1488 auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1489 if (!shape)
1490 return {};
1491 int64_t rank = shape.getNumElements();
1492 Builder builder(getContext());
1493 return builder.getIndexAttr(rank);
1494}
1495
1496/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1497/// Constant folding fails in cases where only the rank is constant, not the
1498/// shape itself.
1499/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1500///
1501/// Example:
1502///
1503/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1504/// %rank = shape.rank %shape
1505///
1506/// becomes
1507///
1508/// %rank = shape.const_size 3
1509
1510namespace {
1511struct RankShapeOfCanonicalizationPattern
1512 : public OpRewritePattern<shape::RankOp> {
1513 using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1514
1515 LogicalResult matchAndRewrite(shape::RankOp op,
1516 PatternRewriter &rewriter) const override {
1517 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1518 if (!shapeOfOp)
1519 return failure();
1520 auto rankedTensorType =
1521 llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1522 if (!rankedTensorType)
1523 return failure();
1524 int64_t rank = rankedTensorType.getRank();
1525 if (llvm::isa<IndexType>(op.getType())) {
1526 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1527 rank);
1528 } else if (llvm::isa<shape::SizeType>(op.getType())) {
1529 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1530 } else {
1531 return failure();
1532 }
1533 return success();
1534 }
1535};
1536} // namespace
1537
1538void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1539 MLIRContext *context) {
1540 patterns.add<RankShapeOfCanonicalizationPattern>(context);
1541}
1542
1543LogicalResult mlir::shape::RankOp::inferReturnTypes(
1544 MLIRContext *context, std::optional<Location> location,
1545 RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1546 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1547 inferredReturnTypes.assign({SizeType::get(context)});
1548 else
1549 inferredReturnTypes.assign({IndexType::get(context)});
1550 return success();
1551}
1552
1553bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1554 // SizeType is compatible with IndexType.
1556}
1557
1558LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1559
1560//===----------------------------------------------------------------------===//
1561// NumElementsOp
1562//===----------------------------------------------------------------------===//
1563
1564OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1565
1566 // Fold only when argument constant.
1567 Attribute shape = adaptor.getShape();
1568 if (!shape)
1569 return {};
1570
1571 APInt product(64, 1);
1572 for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
1573 product *= value;
1574 Builder builder(getContext());
1575 return builder.getIndexAttr(product.getLimitedValue());
1576}
1577
1578LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1579 MLIRContext *context, std::optional<Location> location,
1580 NumElementsOp::Adaptor adaptor,
1581 SmallVectorImpl<Type> &inferredReturnTypes) {
1582 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1583 inferredReturnTypes.assign({SizeType::get(context)});
1584 else
1585 inferredReturnTypes.assign({IndexType::get(context)});
1586 return success();
1587}
1588
1589bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1590 TypeRange r) {
1591 // SizeType is compatible with IndexType.
1593}
1594
1595LogicalResult shape::NumElementsOp::verify() {
1596 return verifySizeOrIndexOp(*this);
1597}
1598
1599//===----------------------------------------------------------------------===//
1600// MaxOp
1601//===----------------------------------------------------------------------===//
1602
1603OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
1604 // If operands are equal, just propagate one.
1605 if (getLhs() == getRhs())
1606 return getLhs();
1607 return nullptr;
1608}
1609
1610LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1611 MLIRContext *context, std::optional<Location> location,
1612 MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1613 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1614 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1615 else
1616 inferredReturnTypes.assign({SizeType::get(context)});
1617 return success();
1618}
1619
1620bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1621 if (l.size() != 1 || r.size() != 1)
1622 return false;
1623 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1624 return true;
1625 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1626 return true;
1627 return false;
1628}
1629
1630//===----------------------------------------------------------------------===//
1631// MinOp
1632//===----------------------------------------------------------------------===//
1633
1634OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
1635 // If operands are equal, just propagate one.
1636 if (getLhs() == getRhs())
1637 return getLhs();
1638 return nullptr;
1639}
1640
1641LogicalResult mlir::shape::MinOp::inferReturnTypes(
1642 MLIRContext *context, std::optional<Location> location,
1643 MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1644 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1645 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1646 else
1647 inferredReturnTypes.assign({SizeType::get(context)});
1648 return success();
1649}
1650
1651bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1652 if (l.size() != 1 || r.size() != 1)
1653 return false;
1654 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1655 return true;
1656 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1657 return true;
1658 return false;
1659}
1660
1661//===----------------------------------------------------------------------===//
1662// MulOp
1663//===----------------------------------------------------------------------===//
1664
1665OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1666 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1667 if (!lhs)
1668 return nullptr;
1669 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1670 if (!rhs)
1671 return nullptr;
1672 APInt folded = lhs.getValue() * rhs.getValue();
1673 Type indexTy = IndexType::get(getContext());
1674 return IntegerAttr::get(indexTy, folded);
1675}
1676
1677LogicalResult mlir::shape::MulOp::inferReturnTypes(
1678 MLIRContext *context, std::optional<Location> location,
1679 MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1680 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1681 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1682 inferredReturnTypes.assign({SizeType::get(context)});
1683 else
1684 inferredReturnTypes.assign({IndexType::get(context)});
1685 return success();
1686}
1687
1688bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1689 // SizeType is compatible with IndexType.
1691}
1692
1693LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1694
1695//===----------------------------------------------------------------------===//
1696// ShapeOfOp
1697//===----------------------------------------------------------------------===//
1698
1699namespace {
1700/// Replace shape_of(x) where x has a constant shape with a const_shape op.
1701struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
1702 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1703
1704 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1705 PatternRewriter &rewriter) const override {
1706 auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1707 if (!type || !type.hasStaticShape())
1708 return failure();
1709
1710 Type resultType = op.getResult().getType();
1711 Location loc = op.getLoc();
1712 Type constResType =
1713 isa<ShapeType>(resultType)
1714 ? resultType
1715 : RankedTensorType::get({type.getRank()}, rewriter.getIndexType());
1716 Value constShape =
1717 ConstShapeOp::create(rewriter, loc, constResType,
1718 rewriter.getIndexTensorAttr(type.getShape()))
1719 .getResult();
1720 if (constShape.getType() != resultType)
1721 constShape =
1722 tensor::CastOp::create(rewriter, loc, resultType, constShape);
1723 rewriter.replaceOp(op, constShape);
1724 return success();
1725 }
1726};
1727
// Canonicalize
//
// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
//
// to
//
// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
// %1 = %shape
//
struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
  using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(shape::ShapeOfOp op,
                                PatternRewriter &rewriter) const override {
    auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
    if (!tensorReshapeOp)
      return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
    if (!isa<TensorType>(op.getType()))
      return rewriter.notifyMatchFailure(op, "result is not a tensor");

    // Operand 'shape' of 'tensor.reshape' may now be used as the result of
    // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
    // formed IR, it may not be identical (dynamically vs statically shaped),
    // in which case it needs to be cast first using 'tensor.cast'.
    // Additionally, it may not have identical element type (i32 vs index)
    // while it has identical shaped type (dynamic vs static), in which case it
    // needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
    // op result must be shape or extent tensor.
    Value shape = tensorReshapeOp.getShape();

    auto opTensorTy = cast<RankedTensorType>(op.getType());
    auto shapeTensorTy = cast<RankedTensorType>(shape.getType());

    if (opTensorTy != shapeTensorTy) {
      // Same element type: only the static/dynamic extent differs, so a
      // tensor.cast suffices.
      if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
        shape =
            tensor::CastOp::create(rewriter, op.getLoc(), opTensorTy, shape);
      // Different element types (e.g. i32 vs index): convert element-wise.
      else if (!isExtentTensorType(shapeTensorTy))
        shape = arith::IndexCastOp::create(rewriter, op.getLoc(), opTensorTy,
                                           shape);
    }

    rewriter.replaceOp(op, shape);
    return success();
  }
};
1775
1776// Canonicalize
1777// ```
1778// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1779// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1780// ```
1781// to
1782// ```
1783// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1784// ```
1785struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1786 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1787
1788 LogicalResult matchAndRewrite(tensor::CastOp op,
1789 PatternRewriter &rewriter) const override {
1790 auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1791 if (!ty || ty.getRank() != 1)
1792 return failure();
1793
1794 auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1795 if (!shapeOfOp)
1796 return failure();
1797
1798 // Argument type must be ranked and must not conflict.
1799 auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1800 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1801 return failure();
1802
1803 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1804 return success();
1805 }
1806};
1807} // namespace
1808
1809void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1810 MLIRContext *context) {
1811 patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1812 ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1813 context);
1814}
1815
1816LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1817 MLIRContext *context, std::optional<Location> location,
1818 ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1819 if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1820 inferredReturnTypes.assign({ShapeType::get(context)});
1821 else {
1822 auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1823 int64_t rank =
1824 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1825 Type indexTy = IndexType::get(context);
1826 Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1827 inferredReturnTypes.assign({extentTensorTy});
1828 }
1829 return success();
1830}
1831
1832bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1833 if (l.size() != 1 || r.size() != 1)
1834 return false;
1835 if (l == r)
1836 return true;
1837
1838 Type lhs = l.front();
1839 Type rhs = r.front();
1840
1841 if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1842 !llvm::isa<ShapeType, ShapedType>(rhs))
1843 return false;
1844
1845 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1846 // Shape type is compatible with all other valid return types.
1847 return true;
1848
1849 if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1850 return true;
1851 return false;
1852}
1853
1854LogicalResult shape::ShapeOfOp::verify() {
1855 return verifyShapeOrExtentTensorOp(*this);
1856}
1857
1858//===----------------------------------------------------------------------===//
1859// SizeToIndexOp
1860//===----------------------------------------------------------------------===//
1861
1862OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1863 // Constant values of both types, `shape.size` and `index`, are represented as
1864 // `IntegerAttr`s which makes constant folding simple.
1865 if (Attribute arg = adaptor.getArg())
1866 return arg;
1867 return OpFoldResult();
1868}
1869
1870void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1871 MLIRContext *context) {
1872 patterns.add<IndexToSizeToIndexCanonicalization>(context);
1873}
1874
1875bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1876 if (inputs.size() != 1 || outputs.size() != 1)
1877 return false;
1878 return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1879 llvm::isa<IndexType>(outputs[0]);
1880}
1881
1882//===----------------------------------------------------------------------===//
1883// YieldOp
1884//===----------------------------------------------------------------------===//
1885
1886LogicalResult shape::YieldOp::verify() {
1887 auto *parentOp = (*this)->getParentOp();
1888 auto results = parentOp->getResults();
1889 auto operands = getOperands();
1890
1891 if (parentOp->getNumResults() != getNumOperands())
1892 return emitOpError() << "number of operands does not match number of "
1893 "results of its parent";
1894 for (auto e : llvm::zip(results, operands))
1895 if (std::get<0>(e).getType() != std::get<1>(e).getType())
1896 return emitOpError() << "types mismatch between yield op and its parent";
1897
1898 return success();
1899}
1900
1901//===----------------------------------------------------------------------===//
1902// SplitAtOp
1903//===----------------------------------------------------------------------===//
1904
1905LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1907 if (!adaptor.getOperand() || !adaptor.getIndex())
1908 return failure();
1909 auto shapeVec = llvm::to_vector<6>(
1910 llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1911 auto shape = llvm::ArrayRef(shapeVec);
1912 auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1913 // Verify that the split point is in the correct range.
1914 // TODO: Constant fold to an "error".
1915 int64_t rank = shape.size();
1916 if (-rank > splitPoint || splitPoint > rank)
1917 return failure();
1918 if (splitPoint < 0)
1919 splitPoint += shape.size();
1920 Builder builder(adaptor.getOperand().getContext());
1921 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1922 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1923 return success();
1924}
1925
1926//===----------------------------------------------------------------------===//
1927// ToExtentTensorOp
1928//===----------------------------------------------------------------------===//
1929
1930OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1931 if (!adaptor.getInput())
1932 return OpFoldResult();
1933 Builder builder(getContext());
1934 auto shape = llvm::to_vector<6>(
1935 llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1936 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1937 builder.getIndexType());
1938 return DenseIntElementsAttr::get(type, shape);
1939}
1940
1941bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1942 if (inputs.size() != 1 || outputs.size() != 1)
1943 return false;
1944 if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1945 if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1946 inputTensor.getRank() != 1)
1947 return false;
1948 } else if (!llvm::isa<ShapeType>(inputs[0])) {
1949 return false;
1950 }
1951
1952 TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1953 return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1954}
1955
1956//===----------------------------------------------------------------------===//
1957// ReduceOp
1958//===----------------------------------------------------------------------===//
1959
1960void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
1961 ValueRange initVals) {
1962 OpBuilder::InsertionGuard g(builder);
1963 result.addOperands(shape);
1964 result.addOperands(initVals);
1965
1966 Region *bodyRegion = result.addRegion();
1967 Block *bodyBlock = builder.createBlock(
1968 bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);
1969
1970 Type elementType;
1971 if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
1972 elementType = tensorType.getElementType();
1973 else
1974 elementType = SizeType::get(builder.getContext());
1975 bodyBlock->addArgument(elementType, shape.getLoc());
1976
1977 for (Value initVal : initVals) {
1978 bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
1979 result.addTypes(initVal.getType());
1980 }
1981}
1982
1983LogicalResult ReduceOp::verify() {
1984 // Verify block arg types.
1985 Block &block = getRegion().front();
1986
1987 // The block takes index, extent, and aggregated values as arguments.
1988 auto blockArgsCount = getInitVals().size() + 2;
1989 if (block.getNumArguments() != blockArgsCount)
1990 return emitOpError() << "ReduceOp body is expected to have "
1991 << blockArgsCount << " arguments";
1992
1993 // The first block argument is the index and must always be of type `index`.
1994 if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
1995 return emitOpError(
1996 "argument 0 of ReduceOp body is expected to be of IndexType");
1997
1998 // The second block argument is the extent and must be of type `size` or
1999 // `index`, depending on whether the reduce operation is applied to a shape or
2000 // to an extent tensor.
2001 Type extentTy = block.getArgument(1).getType();
2002 if (llvm::isa<ShapeType>(getShape().getType())) {
2003 if (!llvm::isa<SizeType>(extentTy))
2004 return emitOpError("argument 1 of ReduceOp body is expected to be of "
2005 "SizeType if the ReduceOp operates on a ShapeType");
2006 } else {
2007 if (!llvm::isa<IndexType>(extentTy))
2008 return emitOpError(
2009 "argument 1 of ReduceOp body is expected to be of IndexType if the "
2010 "ReduceOp operates on an extent tensor");
2011 }
2012
2013 for (const auto &type : llvm::enumerate(getInitVals()))
2014 if (block.getArgument(type.index() + 2).getType() != type.value().getType())
2015 return emitOpError() << "type mismatch between argument "
2016 << type.index() + 2
2017 << " of ReduceOp body and initial value "
2018 << type.index();
2019 return success();
2020}
2021
2022ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2023 // Parse operands.
2025 Type shapeOrExtentTensorType;
2026 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
2028 parser.parseColonType(shapeOrExtentTensorType) ||
2029 parser.parseOptionalArrowTypeList(result.types))
2030 return failure();
2031
2032 // Resolve operands.
2033 auto initVals = llvm::ArrayRef(operands).drop_front();
2034 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
2035 result.operands) ||
2036 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2037 result.operands))
2038 return failure();
2039
2040 // Parse the body.
2041 Region *body = result.addRegion();
2042 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
2043 return failure();
2044
2045 // Parse attributes.
2046 if (parser.parseOptionalAttrDict(result.attributes))
2047 return failure();
2048
2049 return success();
2050}
2051
2052void ReduceOp::print(OpAsmPrinter &p) {
2053 p << '(' << getShape() << ", " << getInitVals()
2054 << ") : " << getShape().getType();
2055 p.printOptionalArrowTypeList(getResultTypes());
2056 p << ' ';
2057 p.printRegion(getRegion());
2058 p.printOptionalAttrDict((*this)->getAttrs());
2059}
2060
2061#define GET_OP_CLASSES
2062#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2063
2064#define GET_TYPEDEF_CLASSES
2065#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool isErrorPropagationPossible(TypeRange operandTypes)
Definition Shape.cpp:66
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
Definition Shape.cpp:978
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
Definition Shape.cpp:83
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
Definition Shape.cpp:96
static LogicalResult verifySizeOrIndexOp(Operation *op)
Definition Shape.cpp:71
static int64_t product(ArrayRef< int64_t > vals)
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:80
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:93
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:266
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:197
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:98
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition Builders.h:447
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:438
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:444
This class represents a single result from folding an operation.
A trait used to provide symbol table functionalities to a region operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
iterator_range< dialect_attr_iterator > dialect_attr_range
Definition Operation.h:634
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
ValueTypeRange< ValueRange > type_range
Definition ValueRange.h:418
Type front()
Return first type in the range.
Definition TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition Traits.cpp:24
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition Traits.cpp:59
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
bool isExtentTensorType(Type)
Definition Shape.cpp:44
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
Definition Shape.cpp:49
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition Shape.cpp:40
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.