MLIR 23.0.0git
Shape.cpp
Go to the documentation of this file.
1//===- Shape.cpp - MLIR Shape Operations ----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
15#include "mlir/Dialect/Traits.h"
17#include "mlir/IR/Builders.h"
20#include "mlir/IR/Matchers.h"
25#include "llvm/ADT/SetOperations.h"
26#include "llvm/ADT/SmallVectorExtras.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/raw_ostream.h"
29#include <utility>
30
31using namespace mlir;
32using namespace mlir::shape;
33
34#include "mlir/Dialect/Shape/IR/ShapeOpsDialect.cpp.inc"
35
36namespace {
37#include "ShapeCanonicalization.inc"
38} // namespace
39
40RankedTensorType shape::getExtentTensorType(MLIRContext *ctx, int64_t rank) {
41 return RankedTensorType::get({rank}, IndexType::get(ctx));
42}
43
45 auto ranked = llvm::dyn_cast<RankedTensorType>(type);
46 return ranked && ranked.getRank() == 1 && ranked.getElementType().isIndex();
47}
48
49LogicalResult shape::getShapeVec(Value input,
50 SmallVectorImpl<int64_t> &shapeValues) {
51 if (auto inputOp = input.getDefiningOp<ShapeOfOp>()) {
52 auto type = llvm::cast<ShapedType>(inputOp.getArg().getType());
53 if (!type.hasRank())
54 return failure();
55 llvm::append_range(shapeValues, type.getShape());
56 return success();
57 }
59 if (matchPattern(input, m_Constant(&attr))) {
60 llvm::append_range(shapeValues, attr.getValues<int64_t>());
61 return success();
62 }
63 return failure();
64}
65
66static bool isErrorPropagationPossible(TypeRange operandTypes) {
67 return llvm::any_of(operandTypes,
68 llvm::IsaPred<SizeType, ShapeType, ValueShapeType>);
69}
70
71static LogicalResult verifySizeOrIndexOp(Operation *op) {
72 assert(op != nullptr && op->getNumResults() == 1);
73 Type resultTy = op->getResultTypes().front();
75 if (!llvm::isa<SizeType>(resultTy))
76 return op->emitOpError()
77 << "if at least one of the operands can hold error values then "
78 "the result must be of type `size` to propagate them";
79 }
80 return success();
81}
82
83static LogicalResult verifyShapeOrExtentTensorOp(Operation *op) {
84 assert(op != nullptr && op->getNumResults() == 1);
85 Type resultTy = op->getResultTypes().front();
87 if (!llvm::isa<ShapeType>(resultTy))
88 return op->emitOpError()
89 << "if at least one of the operands can hold error values then "
90 "the result must be of type `shape` to propagate them";
91 }
92 return success();
93}
94
95template <typename... Ty>
96static bool eachHasOnlyOneOfTypes(TypeRange typeRange) {
97 return typeRange.size() == 1 && llvm::isa<Ty...>(typeRange.front());
98}
99
100template <typename... Ty, typename... ranges>
101static bool eachHasOnlyOneOfTypes(TypeRange l, ranges... rs) {
102 return eachHasOnlyOneOfTypes<Ty...>(l) && eachHasOnlyOneOfTypes<Ty...>(rs...);
103}
104
105//===----------------------------------------------------------------------===//
106// InlinerInterface
107//===----------------------------------------------------------------------===//
108
109namespace {
110/// This class defines the interface for inlining shape dialect ops.
/// This class defines the interface for inlining shape dialect ops.
/// Both hooks unconditionally return true: the shape dialect places no
/// restrictions on inlining its regions or operations.
struct ShapeInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  // Returns true if the given region 'src' can be inlined into the region
  // 'dest' that is attached to an operation registered to the current dialect.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       IRMapping &) const final {
    return true;
  }

  // Returns true if the given operation 'op', that is registered to this
  // dialect, can be inlined into the region 'dest' that is attached to an
  // operation registered to the current dialect.
  bool isLegalToInline(Operation *op, Region *dest, bool wouldBeCloned,
                       IRMapping &) const final {
    return true;
  }
};
129} // namespace
130
/// Registers the dialect's ops, types, and interfaces. Called once when the
/// dialect is loaded into a context.
void ShapeDialect::initialize() {
  // Ops and types are generated by TableGen; the GET_*_LIST macros expand to
  // the comma-separated list of registered classes.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
      >();
  addInterfaces<ShapeInlinerInterface>();
  // Allow unknown operations during prototyping and testing. As the dialect is
  // still evolving it makes it simple to start with an unregistered ops and
  // try different variants before actually defining the op.
  allowUnknownOperations();
  // Bufferization interface implementations live in a separate library; they
  // are promised here and attached lazily when that library is loaded.
  declarePromisedInterfaces<bufferization::BufferizableOpInterface, AssumingOp,
                            AssumingYieldOp>();
}
148
149Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
150 Attribute value, Type type,
151 Location loc) {
152 if (auto poison = dyn_cast<ub::PoisonAttr>(value))
153 return ub::PoisonOp::create(builder, loc, type, poison);
154
155 if (llvm::isa<ShapeType>(type) || isExtentTensorType(type))
156 return ConstShapeOp::create(builder, loc, type,
157 llvm::cast<DenseIntElementsAttr>(value));
158 if (llvm::isa<SizeType>(type))
159 return ConstSizeOp::create(builder, loc, type,
160 llvm::cast<IntegerAttr>(value));
161 if (llvm::isa<WitnessType>(type))
162 return ConstWitnessOp::create(builder, loc, type,
163 llvm::cast<BoolAttr>(value));
164
165 return arith::ConstantOp::materialize(builder, value, type, loc);
166}
167
168LogicalResult ShapeDialect::verifyOperationAttribute(Operation *op,
169 NamedAttribute attribute) {
170 // Verify shape.lib attribute.
171 if (attribute.getName() == "shape.lib") {
172 if (!op->hasTrait<OpTrait::SymbolTable>())
173 return op->emitError(
174 "shape.lib attribute may only be on op implementing SymbolTable");
175
176 if (auto symbolRef = llvm::dyn_cast<SymbolRefAttr>(attribute.getValue())) {
177 auto *symbol = SymbolTable::lookupSymbolIn(op, symbolRef);
178 if (!symbol)
179 return op->emitError("shape function library ")
180 << symbolRef << " not found";
181 return isa<shape::FunctionLibraryOp>(symbol)
182 ? success()
183 : op->emitError()
184 << symbolRef << " required to be shape function library";
185 }
186
187 if (auto arr = llvm::dyn_cast<ArrayAttr>(attribute.getValue())) {
188 // Verify all entries are function libraries and mappings in libraries
189 // refer to unique ops.
191 for (auto it : arr) {
192 if (!llvm::isa<SymbolRefAttr>(it))
193 return op->emitError(
194 "only SymbolRefAttr allowed in shape.lib attribute array");
195
196 auto shapeFnLib = dyn_cast<shape::FunctionLibraryOp>(
197 SymbolTable::lookupSymbolIn(op, llvm::cast<SymbolRefAttr>(it)));
198 if (!shapeFnLib)
199 return op->emitError()
200 << it << " does not refer to FunctionLibraryOp";
201 for (auto mapping : shapeFnLib.getMapping()) {
202 if (!key.insert(mapping.getName()).second) {
203 return op->emitError("only one op to shape mapping allowed, found "
204 "multiple for `")
205 << mapping.getName() << "`";
206 }
207 }
208 }
209 return success();
210 }
211
212 return op->emitError("only SymbolRefAttr or array of SymbolRefAttrs "
213 "allowed as shape.lib attribute");
214 }
215 return success();
216}
217
218//===----------------------------------------------------------------------===//
219// AnyOp
220//===----------------------------------------------------------------------===//
221
222// TODO: Canonicalization should be implemented for shapes that can be
223// determined through mixtures of the known dimensions of the inputs.
224OpFoldResult AnyOp::fold(FoldAdaptor adaptor) {
225 // Only the last operand is checked because AnyOp is commutative.
226 if (adaptor.getInputs().back())
227 return adaptor.getInputs().back();
228
229 return nullptr;
230}
231
232//===----------------------------------------------------------------------===//
233// AssumingOp
234//===----------------------------------------------------------------------===//
235
236ParseResult AssumingOp::parse(OpAsmParser &parser, OperationState &result) {
237 result.regions.reserve(1);
238 Region *doRegion = result.addRegion();
239
240 auto &builder = parser.getBuilder();
242 if (parser.parseOperand(cond) ||
243 parser.resolveOperand(cond, builder.getType<WitnessType>(),
244 result.operands))
245 return failure();
246
247 // Parse optional results type list.
248 if (parser.parseOptionalArrowTypeList(result.types))
249 return failure();
250
251 // Parse the region and add a terminator if elided.
252 if (parser.parseRegion(*doRegion, /*arguments=*/{}, /*argTypes=*/{}))
253 return failure();
254 AssumingOp::ensureTerminator(*doRegion, parser.getBuilder(), result.location);
255
256 // Parse the optional attribute list.
257 if (parser.parseOptionalAttrDict(result.attributes))
258 return failure();
259 return success();
260}
261
262void AssumingOp::print(OpAsmPrinter &p) {
263 bool yieldsResults = !getResults().empty();
264
265 p << " " << getWitness();
266 if (yieldsResults)
267 p << " -> (" << getResultTypes() << ")";
268 p << ' ';
269 p.printRegion(getDoRegion(),
270 /*printEntryBlockArgs=*/false,
271 /*printBlockTerminators=*/yieldsResults);
272 p.printOptionalAttrDict((*this)->getAttrs());
273}
274
275namespace {
276// Removes AssumingOp with a passing witness and inlines the region.
277struct AssumingWithTrue : public OpRewritePattern<AssumingOp> {
278 using OpRewritePattern<AssumingOp>::OpRewritePattern;
279
280 LogicalResult matchAndRewrite(AssumingOp op,
281 PatternRewriter &rewriter) const override {
282 auto witness = op.getWitness().getDefiningOp<ConstWitnessOp>();
283 if (!witness || !witness.getPassingAttr())
284 return failure();
285
286 AssumingOp::inlineRegionIntoParent(op, rewriter);
287 return success();
288 }
289};
290
// Rebuilds an AssumingOp without the results that have no uses, shrinking the
// yield accordingly. Fails (no change) if every result is used.
struct AssumingOpRemoveUnusedResults : public OpRewritePattern<AssumingOp> {
  using OpRewritePattern<AssumingOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingOp op,
                                PatternRewriter &rewriter) const override {
    Block *body = op.getBody();
    auto yieldOp = llvm::cast<AssumingYieldOp>(body->getTerminator());

    // Find used values.
    SmallVector<Value, 4> newYieldOperands;
    for (auto [opResult, yieldOperand] :
         llvm::zip(op.getResults(), yieldOp.getOperands())) {
      if (!opResult.getUses().empty()) {
        newYieldOperands.push_back(yieldOperand);
      }
    }

    // Rewrite only if redundant results exist.
    if (newYieldOperands.size() == yieldOp->getNumOperands())
      return failure();

    // Replace yield op in the old assuming op's body and move the entire region
    // to the new assuming op.
    rewriter.setInsertionPointToEnd(body);
    auto newYieldOp =
        rewriter.replaceOpWithNewOp<AssumingYieldOp>(yieldOp, newYieldOperands);
    rewriter.setInsertionPoint(op);
    auto newOp = AssumingOp::create(
        rewriter, op.getLoc(), newYieldOp->getOperandTypes(), op.getWitness());
    // takeBody moves the region wholesale; the old op is left with an empty
    // region and is replaced below.
    newOp.getDoRegion().takeBody(op.getDoRegion());

    // Use the new results to replace the previously used ones. Unused old
    // results map to null (no replacement needed); used ones consume the new
    // op's results in order.
    SmallVector<Value, 4> replacementValues;
    auto src = newOp.getResults().begin();
    for (auto it : op.getResults()) {
      if (it.getUses().empty())
        replacementValues.push_back(nullptr);
      else
        replacementValues.push_back(*src++);
    }
    rewriter.replaceOp(op, replacementValues);
    return success();
  }
};
335} // namespace
336
/// Registers the canonicalization patterns defined above for AssumingOp.
void AssumingOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
  patterns.add<AssumingOpRemoveUnusedResults, AssumingWithTrue>(context);
}
341
342// See RegionBranchOpInterface in Interfaces/ControlFlowInterfaces.td
343void AssumingOp::getSuccessorRegions(
345 // AssumingOp has unconditional control flow into the region and back to the
346 // parent, so return the correct RegionSuccessor purely based on the index
347 // being None or 0.
348 if (!point.isParent()) {
349 regions.push_back(RegionSuccessor::parent());
350 return;
351 }
352
353 regions.push_back(RegionSuccessor(&getDoRegion()));
354}
355
356ValueRange AssumingOp::getSuccessorInputs(RegionSuccessor successor) {
357 return successor.isParent() ? ValueRange(getResults()) : ValueRange();
358}
359
/// Inlines the body of `op` into its parent block at the rewriter's current
/// insertion point, replacing the op's results with the yielded values.
/// Used when the witness is known to pass.
void AssumingOp::inlineRegionIntoParent(AssumingOp &op,
                                        PatternRewriter &rewriter) {
  // Split the parent block at the insertion point so the assuming body can be
  // spliced between the two halves.
  auto *blockBeforeAssuming = rewriter.getInsertionBlock();
  auto *assumingBlock = op.getBody();
  auto initPosition = rewriter.getInsertionPoint();
  auto *blockAfterAssuming =
      rewriter.splitBlock(blockBeforeAssuming, initPosition);

  // Remove the AssumingOp and AssumingYieldOp.
  // The yield's operands become the replacements for the op's results; the
  // yield itself is then dead and erased.
  auto &yieldOp = assumingBlock->back();
  rewriter.inlineRegionBefore(op.getDoRegion(), blockAfterAssuming);
  rewriter.replaceOp(op, yieldOp.getOperands());
  rewriter.eraseOp(&yieldOp);

  // Merge blocks together as there was no branching behavior from the
  // AssumingOp.
  rewriter.mergeBlocks(assumingBlock, blockBeforeAssuming);
  rewriter.mergeBlocks(blockAfterAssuming, blockBeforeAssuming);
}
379
380void AssumingOp::build(
381 OpBuilder &builder, OperationState &result, Value witness,
383 OpBuilder::InsertionGuard g(builder);
384
385 result.addOperands(witness);
386 Region *bodyRegion = result.addRegion();
387 builder.createBlock(bodyRegion);
388
389 // Build body.
390 SmallVector<Value, 2> yieldValues = bodyBuilder(builder, result.location);
391 AssumingYieldOp::create(builder, result.location, yieldValues);
392
393 SmallVector<Type, 2> assumingTypes;
394 for (Value v : yieldValues)
395 assumingTypes.push_back(v.getType());
396 result.addTypes(assumingTypes);
397}
398
399//===----------------------------------------------------------------------===//
400// AddOp
401//===----------------------------------------------------------------------===//
402
403LogicalResult mlir::shape::AddOp::inferReturnTypes(
404 MLIRContext *context, std::optional<Location> location,
405 AddOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
406 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
407 llvm::isa<SizeType>(adaptor.getRhs().getType()))
408 inferredReturnTypes.assign({SizeType::get(context)});
409 else
410 inferredReturnTypes.assign({IndexType::get(context)});
411 return success();
412}
413
414bool mlir::shape::AddOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
415 // SizeType is compatible with IndexType.
417}
418
419OpFoldResult mlir::shape::AddOp::fold(FoldAdaptor adaptor) {
420 // add(x, 0) -> x
421 if (matchPattern(getRhs(), m_Zero()))
422 return getLhs();
423
425 adaptor.getOperands(),
426 [](APInt a, const APInt &b) { return std::move(a) + b; });
427}
428
// Delegates to the shared size/index verifier defined above.
LogicalResult shape::AddOp::verify() { return verifySizeOrIndexOp(*this); }
430
431//===----------------------------------------------------------------------===//
432// AssumingAllOp
433//===----------------------------------------------------------------------===//
434
435namespace {
436
437// Merge multiple `shape.assuming_all` operations together.
438//
439// %0 = shape.assuming_all %w0, %w1
440// %1 = shape.assuming_all %w2, %0
441//
442// to:
443//
444// %0 = shape.assuming_all %w0, %w2, %w2
445struct MergeAssumingAllOps : public OpRewritePattern<AssumingAllOp> {
446 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
447
448 LogicalResult matchAndRewrite(AssumingAllOp op,
449 PatternRewriter &rewriter) const override {
450 SmallVector<Value> operands;
451
452 for (Value operand : op.getInputs()) {
453 if (auto assumeAll = operand.getDefiningOp<AssumingAllOp>())
454 operands.append(assumeAll.operand_begin(), assumeAll->operand_end());
455 else
456 operands.push_back(operand);
457 }
458
459 // We didn't find any other `assuming_all` ops to merge with.
460 if (operands.size() == op.getNumOperands())
461 return failure();
462
463 // Replace with a new `assuming_all` operation with merged constraints.
464 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, operands);
465 return success();
466 }
467};
468
469// Eliminate `cstr_broadcastable` operands from `assuming_all` operation that
470// are subsumed by others.
471//
472// %0 = shape.cstr_broadcastable %shape0, %shape1
473// %1 = shape.cstr_broadcastable %shape0, %shape1, %shape2
474//
475// %2 = shape.cstr_broadcastable %shape3, %shape4
476// %3 = shape.cstr_broadcastable %shape3, %shape4, %shape5
477//
478// %4 = shape.assuming_all %0, %1, %2, %3
479//
480// to:
481//
482// %0 = shape.cstr_broadcastable %shape0, %shape1, %shape2
483// %1 = shape.cstr_broadcastable %shape3, %shape4, %shape5
484// %2 = shape.assuming_all %0, %1
485//
486// In this example if shapes [0, 1, 2] are broadcastable, then it means that
487// shapes [0, 1] are broadcastable too, and can be removed from the list of
488// constraints. If shapes [0, 1, 2] are not broadcastable, then it doesn't
489// matter if shapes [0, 1] are broadcastable (same for shapes [3, 4, 5]).
490struct AssumingAllOfCstrBroadcastable : public OpRewritePattern<AssumingAllOp> {
491 using OpRewritePattern<AssumingAllOp>::OpRewritePattern;
492
493 LogicalResult matchAndRewrite(AssumingAllOp op,
494 PatternRewriter &rewriter) const override {
495 // Collect all `CstrBroadcastableOp` operands first.
497 for (Value operand : op.getInputs()) {
498 // TODO: Apply this optimization if some of the witnesses are not
499 // produced by the `cstr_broadcastable`.
500 auto broadcastable = operand.getDefiningOp<CstrBroadcastableOp>();
501 if (!broadcastable)
502 return failure();
503
504 operands.insert(broadcastable);
505 }
506
507 // Skip trivial `assuming_all` operations.
508 if (operands.size() <= 1)
509 return failure();
510
511 // Collect shapes checked by `cstr_broadcastable` operands.
512 SmallVector<std::pair<CstrBroadcastableOp, DenseSet<Value>>> shapes;
513 for (auto cstr : operands) {
514 DenseSet<Value> shapesSet(cstr->operand_begin(), cstr->operand_end());
515 shapes.emplace_back(cstr, std::move(shapesSet));
516 }
517
518 // Sort by the number of shape operands (larger to smaller).
519 llvm::sort(shapes, [](auto a, auto b) {
520 return a.first.getNumOperands() > b.first.getNumOperands();
521 });
522
523 // We start from the `cst_broadcastable` operations with largest number of
524 // shape operands, and remove redundant `cst_broadcastable` operations. We
525 // do this until we find a set of `cst_broadcastable` operations with
526 // non-overlapping constraints.
527 SmallVector<CstrBroadcastableOp> markedForErase;
528
529 for (unsigned i = 0; i < shapes.size(); ++i) {
530 auto isSubset = [&](auto pair) {
531 return llvm::set_is_subset(pair.second, shapes[i].second);
532 };
533
534 // Keep redundant `cstr_broadcastable` operations to be erased.
535 auto *it = std::remove_if(shapes.begin() + i + 1, shapes.end(), isSubset);
536 for (auto *it0 = it; it0 < shapes.end(); ++it0)
537 markedForErase.push_back(it0->first);
538 shapes.erase(it, shapes.end());
539 }
540
541 // We didn't find any operands that could be removed.
542 if (markedForErase.empty())
543 return failure();
544
545 // Collect non-overlapping `cst_broadcastable` constraints.
546 SmallVector<Value> uniqueConstraints;
547 for (auto &shape : shapes)
548 uniqueConstraints.push_back(shape.first.getResult());
549
550 // Replace with a new `assuming_all` operation ...
551 rewriter.replaceOpWithNewOp<AssumingAllOp>(op, uniqueConstraints);
552
553 // ... and maybe erase `cstr_broadcastable` ops without uses.
554 for (auto &op : markedForErase)
555 if (op->use_empty())
556 rewriter.eraseOp(op);
557
558 return success();
559 }
560};
561
// Collapses an `assuming_all` whose witnesses are all `cstr_eq` ops into a
// single `cstr_eq` over the union of their shapes. Only fires when the
// `cstr_eq` shape sets chain together (each overlaps what was collected so
// far), so equality is transitive across the union.
struct AssumingAllToCstrEqCanonicalization
    : public OpRewritePattern<AssumingAllOp> {
  using OpRewritePattern<AssumingAllOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AssumingAllOp op,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value, 8> shapes;
    for (Value w : op.getInputs()) {
      auto cstrEqOp = w.getDefiningOp<CstrEqOp>();
      if (!cstrEqOp)
        return failure();
      // True when this cstr_eq shares no shape with those collected so far;
      // such a disjoint constraint cannot be merged into one cstr_eq.
      bool disjointShapes = llvm::none_of(cstrEqOp.getShapes(), [&](Value s) {
        return llvm::is_contained(shapes, s);
      });
      if (!shapes.empty() && !cstrEqOp.getShapes().empty() && disjointShapes)
        return failure();
      shapes.append(cstrEqOp.getShapes().begin(), cstrEqOp.getShapes().end());
    }
    rewriter.replaceOpWithNewOp<CstrEqOp>(op, shapes);
    return success();
  }
};
584
585template <typename OpTy>
586struct RemoveDuplicateOperandsPattern : public OpRewritePattern<OpTy> {
587 using OpRewritePattern<OpTy>::OpRewritePattern;
588
589 LogicalResult matchAndRewrite(OpTy op,
590 PatternRewriter &rewriter) const override {
591 // Find unique operands.
592 SetVector<Value> unique(op.operand_begin(), op.operand_end());
593
594 // Reduce op to equivalent with unique operands.
595 if (unique.size() < op.getNumOperands()) {
596 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(),
597 unique.takeVector(), op->getAttrs());
598 return success();
599 }
600
601 return failure();
602 }
603};
604} // namespace
605
606void AssumingAllOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
607 MLIRContext *context) {
609 .add<MergeAssumingAllOps, AssumingAllOneOp,
610 AssumingAllOfCstrBroadcastable, AssumingAllToCstrEqCanonicalization,
611 RemoveDuplicateOperandsPattern<AssumingAllOp>>(context);
612}
613
/// Folds `assuming_all` over its constant witnesses. Note this fold mutates
/// the op in place: statically known operands are erased as they are handled.
OpFoldResult AssumingAllOp::fold(FoldAdaptor adaptor) {
  // Iterate in reverse to first handle all constant operands. They are
  // guaranteed to be the tail of the inputs because this is commutative.
  for (int idx = adaptor.getInputs().size() - 1; idx >= 0; idx--) {
    Attribute a = adaptor.getInputs()[idx];
    // Cannot fold if any inputs are not constant;
    if (!a)
      return nullptr;

    // We do not need to keep statically known values after handling them in
    // this method.
    getOperation()->eraseOperand(idx);

    // Always false if any input is statically known false
    if (!llvm::cast<BoolAttr>(a).getValue())
      return a;
  }
  // If this is reached, all inputs were statically known passing.
  return BoolAttr::get(getContext(), true);
}
634
635LogicalResult AssumingAllOp::verify() {
636 // Ensure that AssumingAllOp contains at least one operand
637 if (getNumOperands() == 0)
638 return emitOpError("no operands specified");
639
640 return success();
641}
642
643//===----------------------------------------------------------------------===//
644// BroadcastOp
645//===----------------------------------------------------------------------===//
646
647OpFoldResult BroadcastOp::fold(FoldAdaptor adaptor) {
648 if (getShapes().size() == 1) {
649 // Otherwise, we need a cast which would be a canonicalization, not folding.
650 if (getShapes().front().getType() != getType())
651 return nullptr;
652 return getShapes().front();
653 }
654
655 if (!adaptor.getShapes().front())
656 return nullptr;
657
658 SmallVector<int64_t, 6> resultShape(
659 llvm::cast<DenseIntElementsAttr>(adaptor.getShapes().front())
660 .getValues<int64_t>());
661
662 for (auto next : adaptor.getShapes().drop_front()) {
663 if (!next)
664 return nullptr;
665 auto nextShape = llvm::to_vector<6>(
666 llvm::cast<DenseIntElementsAttr>(next).getValues<int64_t>());
667
669 // If the shapes are not compatible, we can't fold it.
670 // TODO: Fold to an "error".
671 if (!OpTrait::util::getBroadcastedShape(resultShape, nextShape, tmpShape))
672 return nullptr;
673
674 resultShape.clear();
675 std::copy(tmpShape.begin(), tmpShape.end(),
676 std::back_inserter(resultShape));
677 }
678
679 Builder builder(getContext());
680 return builder.getIndexTensorAttr(resultShape);
681}
682
// Delegates to the shared shape/extent-tensor verifier defined above.
LogicalResult BroadcastOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}
686
687namespace {
688template <typename OpTy>
689struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern<OpTy> {
690 using OpRewritePattern<OpTy>::OpRewritePattern;
691
692 LogicalResult matchAndRewrite(OpTy op,
693 PatternRewriter &rewriter) const override {
694 auto isPotentiallyNonEmptyShape = [](Value shape) {
695 if (auto extentTensorTy =
696 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
697 if (extentTensorTy.getDimSize(0) == 0)
698 return false;
699 }
700 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
701 if (constShape.getShape().empty())
702 return false;
703 }
704 return true;
705 };
706 auto newOperands = llvm::filter_to_vector<8>(op->getOperands(),
707 isPotentiallyNonEmptyShape);
708
709 // Replace the op with empty shape constant if all operants are reduced to
710 // be empty.
711 if (newOperands.empty()) {
712 rewriter.replaceOpWithNewOp<ConstShapeOp>(
713 op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({}));
714 return success();
715 }
716
717 // Reduce op to equivalent without empty shape operands.
718 if (newOperands.size() < op.getNumOperands()) {
719 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands,
720 op->getAttrs());
721 return success();
722 }
723
724 return failure();
725 }
726};
727
728struct BroadcastForwardSingleOperandPattern
729 : public OpRewritePattern<BroadcastOp> {
730 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
731
732 LogicalResult matchAndRewrite(BroadcastOp op,
733 PatternRewriter &rewriter) const override {
734 if (op.getNumOperands() != 1)
735 return failure();
736 Value replacement = op.getShapes().front();
737
738 // Insert cast if needed.
739 if (replacement.getType() != op.getType()) {
740 auto loc = op.getLoc();
741 if (llvm::isa<ShapeType>(op.getType())) {
742 replacement = FromExtentTensorOp::create(rewriter, loc, replacement);
743 } else {
744 assert(!llvm::isa<ShapeType>(op.getType()) &&
745 !llvm::isa<ShapeType>(replacement.getType()) &&
746 "expect extent tensor cast");
748 tensor::CastOp::create(rewriter, loc, op.getType(), replacement);
749 }
750 }
751
752 rewriter.replaceOp(op, replacement);
753 return success();
754 }
755};
756
757struct BroadcastFoldConstantOperandsPattern
758 : public OpRewritePattern<BroadcastOp> {
759 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
760
761 LogicalResult matchAndRewrite(BroadcastOp op,
762 PatternRewriter &rewriter) const override {
763 SmallVector<int64_t, 8> foldedConstantShape;
764 SmallVector<Value, 8> newShapeOperands;
765 for (Value shape : op.getShapes()) {
766 if (auto constShape = shape.getDefiningOp<ConstShapeOp>()) {
767 SmallVector<int64_t, 8> newFoldedConstantShape;
769 foldedConstantShape,
770 llvm::to_vector<8>(constShape.getShape().getValues<int64_t>()),
771 newFoldedConstantShape)) {
772 foldedConstantShape = newFoldedConstantShape;
773 continue;
774 }
775 }
776 newShapeOperands.push_back(shape);
777 }
778
779 // Need at least two constant operands to fold anything.
780 if (op.getNumOperands() - newShapeOperands.size() < 2)
781 return failure();
782
783 auto foldedConstantOperandsTy = RankedTensorType::get(
784 {static_cast<int64_t>(foldedConstantShape.size())},
785 rewriter.getIndexType());
786 newShapeOperands.push_back(
787 ConstShapeOp::create(rewriter, op.getLoc(), foldedConstantOperandsTy,
788 rewriter.getIndexTensorAttr(foldedConstantShape)));
789 rewriter.replaceOpWithNewOp<BroadcastOp>(op, op.getType(),
790 newShapeOperands);
791 return success();
792 }
793};
794
795template <typename OpTy>
796struct CanonicalizeCastExtentTensorOperandsPattern
797 : public OpRewritePattern<OpTy> {
798 using OpRewritePattern<OpTy>::OpRewritePattern;
799
800 LogicalResult matchAndRewrite(OpTy op,
801 PatternRewriter &rewriter) const override {
802 // Canonicalize operands.
803 bool anyChange = false;
804 auto canonicalizeOperand = [&](Value operand) -> Value {
805 if (auto castOp = operand.getDefiningOp<tensor::CastOp>()) {
806 // Only eliminate the cast if it holds no shape information.
807 bool isInformationLoosingCast =
808 llvm::cast<RankedTensorType>(castOp.getType()).isDynamicDim(0);
809 if (isInformationLoosingCast) {
810 anyChange = true;
811 return castOp.getSource();
812 }
813 }
814 return operand;
815 };
816 auto newOperands =
817 llvm::map_to_vector<8>(op.getOperands(), canonicalizeOperand);
818
819 // Rewrite op if any change required.
820 if (!anyChange)
821 return failure();
822 rewriter.replaceOpWithNewOp<OpTy>(op, op->getResultTypes(), newOperands);
823 return success();
824 }
825};
826
827struct BroadcastConcretizeResultTypePattern
828 : public OpRewritePattern<BroadcastOp> {
829 using OpRewritePattern<BroadcastOp>::OpRewritePattern;
830
831 LogicalResult matchAndRewrite(BroadcastOp op,
832 PatternRewriter &rewriter) const override {
833 // Only concretize dynamic extent tensor result types.
834 auto resultTy = llvm::dyn_cast<RankedTensorType>(op.getType());
835 if (!resultTy || !resultTy.isDynamicDim(0))
836 return failure();
837
838 // Infer resulting shape rank if possible.
839 int64_t maxRank = 0;
840 for (Value shape : op.getShapes()) {
841 if (auto extentTensorTy =
842 llvm::dyn_cast<RankedTensorType>(shape.getType())) {
843 // Cannot infer resulting shape rank if any operand is dynamically
844 // ranked.
845 if (extentTensorTy.isDynamicDim(0))
846 return failure();
847 maxRank = std::max(maxRank, extentTensorTy.getDimSize(0));
848 }
849 }
850
851 auto newOp = BroadcastOp::create(rewriter, op.getLoc(),
853 op.getShapes());
854 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
855 return success();
856 }
857};
858} // namespace
859
/// Registers the canonicalization patterns defined above for BroadcastOp.
void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                              MLIRContext *context) {
  patterns.add<BroadcastConcretizeResultTypePattern,
               BroadcastFoldConstantOperandsPattern,
               BroadcastForwardSingleOperandPattern,
               CanonicalizeCastExtentTensorOperandsPattern<BroadcastOp>,
               RemoveDuplicateOperandsPattern<BroadcastOp>,
               RemoveEmptyShapeOperandsPattern<BroadcastOp>>(context);
}
869
870//===----------------------------------------------------------------------===//
871// ConcatOp
872//===----------------------------------------------------------------------===//
873
874OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
875 if (!adaptor.getLhs() || !adaptor.getRhs())
876 return nullptr;
877 auto lhsShape = llvm::to_vector<6>(
878 llvm::cast<DenseIntElementsAttr>(adaptor.getLhs()).getValues<int64_t>());
879 auto rhsShape = llvm::to_vector<6>(
880 llvm::cast<DenseIntElementsAttr>(adaptor.getRhs()).getValues<int64_t>());
881 SmallVector<int64_t, 6> resultShape;
882 resultShape.append(lhsShape.begin(), lhsShape.end());
883 resultShape.append(rhsShape.begin(), rhsShape.end());
884 Builder builder(getContext());
885 return builder.getIndexTensorAttr(resultShape);
886}
887
888//===----------------------------------------------------------------------===//
889// ConstShapeOp
890//===----------------------------------------------------------------------===//
891
/// Custom printer: ` [attrs] [e0, e1, ...] : type`. The `shape` attribute is
/// elided from the attr-dict because its extents are printed inline.
void ConstShapeOp::print(OpAsmPrinter &p) {
  p << " ";
  p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"shape"});
  p << "[";
  interleaveComma(getShape().getValues<int64_t>(), p);
  p << "] : ";
  p.printType(getType());
}
900
901ParseResult ConstShapeOp::parse(OpAsmParser &parser, OperationState &result) {
902 if (parser.parseOptionalAttrDict(result.attributes))
903 return failure();
904 // We piggy-back on ArrayAttr parsing, though we don't internally store the
905 // shape as an ArrayAttr.
906 // TODO: Implement custom parser and maybe make syntax a bit more concise.
907 Attribute extentsRaw;
908 NamedAttrList dummy;
909 if (parser.parseAttribute(extentsRaw, "dummy", dummy))
910 return failure();
911 auto extentsArray = llvm::dyn_cast<ArrayAttr>(extentsRaw);
912 if (!extentsArray)
913 return failure();
915 for (Attribute extent : extentsArray) {
916 IntegerAttr attr = llvm::dyn_cast<IntegerAttr>(extent);
917 if (!attr)
918 return failure();
919 ints.push_back(attr.getInt());
920 }
921 Builder &builder = parser.getBuilder();
922 result.addAttribute("shape", builder.getIndexTensorAttr(ints));
923 Type resultTy;
924 if (parser.parseColonType(resultTy))
925 return failure();
926 result.types.push_back(resultTy);
927 return success();
928}
929
930OpFoldResult ConstShapeOp::fold(FoldAdaptor) { return getShapeAttr(); }
931
932void ConstShapeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
933 MLIRContext *context) {
934 patterns.add<TensorCastConstShape>(context);
935}
936
937LogicalResult mlir::shape::ConstShapeOp::inferReturnTypes(
938 MLIRContext *context, std::optional<Location> location,
939 ConstShapeOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
940 Builder b(context);
941 const Properties prop = adaptor.getProperties();
942 inferredReturnTypes.assign({RankedTensorType::get(
943 {static_cast<int64_t>(prop.shape.size())}, b.getIndexType())});
944 return success();
945}
946
947bool mlir::shape::ConstShapeOp::isCompatibleReturnTypes(TypeRange l,
948 TypeRange r) {
949 if (l.size() != 1 || r.size() != 1)
950 return false;
951
952 Type lhs = l.front();
953 Type rhs = r.front();
954
955 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
956 // Shape type is compatible with all other valid return types.
957 return true;
958 return lhs == rhs;
959}
960
961//===----------------------------------------------------------------------===//
962// CstrBroadcastableOp
963//===----------------------------------------------------------------------===//
964
965void CstrBroadcastableOp::getCanonicalizationPatterns(
967 // Canonicalization patterns have overlap with the considerations during
968 // folding in case additional shape information is inferred at some point that
969 // does not result in folding.
970 patterns.add<CanonicalizeCastExtentTensorOperandsPattern<CstrBroadcastableOp>,
971 CstrBroadcastableEqOps,
972 RemoveDuplicateOperandsPattern<CstrBroadcastableOp>,
973 RemoveEmptyShapeOperandsPattern<CstrBroadcastableOp>>(context);
974}
975
976// Return true if there is exactly one attribute not representing a scalar
977// broadcast.
979 bool nonScalarSeen = false;
980 for (Attribute a : attributes) {
981 if (!a || llvm::cast<DenseIntElementsAttr>(a).getNumElements() != 0) {
982 if (nonScalarSeen)
983 return false;
984 nonScalarSeen = true;
985 }
986 }
987 return true;
988}
989
990OpFoldResult CstrBroadcastableOp::fold(FoldAdaptor adaptor) {
991 // No broadcasting is needed if all operands but one are scalar.
992 if (hasAtMostSingleNonScalar(adaptor.getShapes()))
993 return BoolAttr::get(getContext(), true);
994
995 if ([&] {
997 for (const auto &operand : adaptor.getShapes()) {
998 if (!operand)
999 return false;
1000 extents.push_back(llvm::to_vector<6>(
1001 llvm::cast<DenseIntElementsAttr>(operand).getValues<int64_t>()));
1002 }
1004 }())
1005 return BoolAttr::get(getContext(), true);
1006
1007 // Lastly, see if folding can be completed based on what constraints are known
1008 // on the input shapes.
1009 if ([&] {
1011 for (auto shapeValue : getShapes()) {
1012 extents.emplace_back();
1013 if (failed(getShapeVec(shapeValue, extents.back())))
1014 return false;
1015 }
1017 }())
1018 return BoolAttr::get(getContext(), true);
1019
1020 // Because a failing witness result here represents an eventual assertion
1021 // failure, we do not replace it with a constant witness.
1022 return nullptr;
1023}
1024
1025LogicalResult CstrBroadcastableOp::verify() {
1026 // Ensure that CstrBroadcastableOp contains at least two operands
1027 if (getNumOperands() < 2)
1028 return emitOpError("required at least 2 input shapes");
1029 return success();
1030}
1031
1032//===----------------------------------------------------------------------===//
1033// CstrEqOp
1034//===----------------------------------------------------------------------===//
1035
1036void CstrEqOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1037 MLIRContext *context) {
1038 // If inputs are equal, return passing witness
1039 patterns.add<CstrEqEqOps>(context);
1040}
1041
1042OpFoldResult CstrEqOp::fold(FoldAdaptor adaptor) {
1043 if (llvm::all_of(adaptor.getShapes(), [&](Attribute a) {
1044 return a && a == adaptor.getShapes().front();
1045 }))
1046 return BoolAttr::get(getContext(), true);
1047
1048 // Because a failing witness result here represents an eventual assertion
1049 // failure, we do not try to replace it with a constant witness. Similarly, we
1050 // cannot if there are any non-const inputs.
1051 return nullptr;
1052}
1053
1054//===----------------------------------------------------------------------===//
1055// ConstSizeOp
1056//===----------------------------------------------------------------------===//
1057
1058void ConstSizeOp::build(OpBuilder &builder, OperationState &result,
1059 int64_t value) {
1060 build(builder, result, builder.getIndexAttr(value));
1061}
1062
1063OpFoldResult ConstSizeOp::fold(FoldAdaptor) { return getValueAttr(); }
1064
1065void ConstSizeOp::getAsmResultNames(
1066 llvm::function_ref<void(Value, StringRef)> setNameFn) {
1067 SmallString<4> buffer;
1068 llvm::raw_svector_ostream os(buffer);
1069 os << "c" << getValue();
1070 setNameFn(getResult(), os.str());
1071}
1072
1073//===----------------------------------------------------------------------===//
1074// ConstWitnessOp
1075//===----------------------------------------------------------------------===//
1076
1077OpFoldResult ConstWitnessOp::fold(FoldAdaptor) { return getPassingAttr(); }
1078
1079//===----------------------------------------------------------------------===//
1080// CstrRequireOp
1081//===----------------------------------------------------------------------===//
1082
OpFoldResult CstrRequireOp::fold(FoldAdaptor adaptor) {
  // Folds iff the predicate operand is a known constant; the adaptor returns
  // the constant BoolAttr (or null when unknown), which propagates directly
  // as the fold result.
  return adaptor.getPred();
}
1086
1087//===----------------------------------------------------------------------===//
1088// DimOp
1089//===----------------------------------------------------------------------===//
1090
1091std::optional<int64_t> DimOp::getConstantIndex() {
1092 if (auto constSizeOp = getIndex().getDefiningOp<ConstSizeOp>())
1093 return constSizeOp.getValue().getLimitedValue();
1094 if (auto constantOp = getIndex().getDefiningOp<arith::ConstantOp>())
1095 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1096 return std::nullopt;
1097}
1098
1099OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1100 Type valType = getValue().getType();
1101 auto valShapedType = llvm::dyn_cast<ShapedType>(valType);
1102 if (!valShapedType || !valShapedType.hasRank())
1103 return nullptr;
1104 std::optional<int64_t> index = getConstantIndex();
1105 if (!index.has_value())
1106 return nullptr;
1107 if (index.value() < 0 || index.value() >= valShapedType.getRank())
1108 return nullptr;
1109 auto extent = valShapedType.getDimSize(*index);
1110 if (ShapedType::isDynamic(extent))
1111 return nullptr;
1112 return IntegerAttr::get(IndexType::get(getContext()), extent);
1113}
1114
1115LogicalResult mlir::shape::DimOp::inferReturnTypes(
1116 MLIRContext *context, std::optional<Location> location,
1117 DimOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1118 inferredReturnTypes.assign({adaptor.getIndex().getType()});
1119 return success();
1120}
1121
1122bool mlir::shape::DimOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1124}
1125
1126//===----------------------------------------------------------------------===//
1127// DivOp
1128//===----------------------------------------------------------------------===//
1129
1130OpFoldResult DivOp::fold(FoldAdaptor adaptor) {
1131 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1132 if (!lhs)
1133 return nullptr;
1134 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1135 if (!rhs || rhs.getValue().isZero())
1136 return nullptr;
1137
1138 // Division in APInt does not follow floor(lhs, rhs) when the result is
1139 // negative. Rather, APInt rounds toward zero.
1140 APInt quotient, remainder;
1141 APInt::sdivrem(lhs.getValue(), rhs.getValue(), quotient, remainder);
1142 if (quotient.isNegative() && !remainder.isZero()) {
1143 quotient -= 1;
1144 }
1145
1146 Type indexTy = IndexType::get(getContext());
1147 return IntegerAttr::get(indexTy, quotient);
1148}
1149
1150LogicalResult mlir::shape::DivOp::inferReturnTypes(
1151 MLIRContext *context, std::optional<Location> location,
1152 DivOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1153 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1154 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1155 inferredReturnTypes.assign({SizeType::get(context)});
1156 else
1157 inferredReturnTypes.assign({IndexType::get(context)});
1158 return success();
1159}
1160
1161bool mlir::shape::DivOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1162 // SizeType is compatible with IndexType.
1164}
1165
1166LogicalResult DivOp::verify() { return verifySizeOrIndexOp(*this); }
1167
1168//===----------------------------------------------------------------------===//
1169// ShapeEqOp
1170//===----------------------------------------------------------------------===//
1171
1172OpFoldResult ShapeEqOp::fold(FoldAdaptor adaptor) {
1173 bool allSame = true;
1174 if (!adaptor.getShapes().empty() && !adaptor.getShapes().front())
1175 return {};
1176 for (Attribute operand : adaptor.getShapes().drop_front()) {
1177 if (!operand)
1178 return {};
1179 allSame = allSame && operand == adaptor.getShapes().front();
1180 }
1181 return BoolAttr::get(getContext(), allSame);
1182}
1183
1184//===----------------------------------------------------------------------===//
1185// IndexToSizeOp
1186//===----------------------------------------------------------------------===//
1187
1188OpFoldResult IndexToSizeOp::fold(FoldAdaptor adaptor) {
1189 // Constant values of both types, `shape.size` and `index`, are represented as
1190 // `IntegerAttr`s which makes constant folding simple.
1191 if (Attribute arg = adaptor.getArg())
1192 return arg;
1193 return {};
1194}
1195
1196void IndexToSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1197 MLIRContext *context) {
1198 patterns.add<SizeToIndexToSizeCanonicalization>(context);
1199}
1200
1201//===----------------------------------------------------------------------===//
1202// FromExtentsOp
1203//===----------------------------------------------------------------------===//
1204
1205OpFoldResult FromExtentsOp::fold(FoldAdaptor adaptor) {
1206 if (llvm::any_of(adaptor.getExtents(), [](Attribute a) { return !a; }))
1207 return nullptr;
1209 for (auto attr : adaptor.getExtents())
1210 extents.push_back(llvm::cast<IntegerAttr>(attr).getInt());
1211 Builder builder(getContext());
1212 return builder.getIndexTensorAttr(extents);
1213}
1214
1215//===----------------------------------------------------------------------===//
1216// FunctionLibraryOp
1217//===----------------------------------------------------------------------===//
1218
1219void FunctionLibraryOp::build(OpBuilder &builder, OperationState &result,
1220 StringRef name) {
1221 result.attributes.push_back(builder.getNamedAttr(
1223}
1224
1225FuncOp FunctionLibraryOp::getShapeFunction(Operation *op) {
1226 auto attr = llvm::dyn_cast_or_null<FlatSymbolRefAttr>(
1227 getMapping().get(op->getName().getIdentifier()));
1228 if (!attr)
1229 return nullptr;
1230 return lookupSymbol<FuncOp>(attr);
1231}
1232
1233ParseResult FunctionLibraryOp::parse(OpAsmParser &parser,
1235 // Parse the op name.
1236 StringAttr nameAttr;
1238 result.attributes))
1239 return failure();
1240
1241 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
1242 return failure();
1243
1244 auto *bodyRegion = result.addRegion();
1245 if (parser.parseRegion(*bodyRegion))
1246 return failure();
1247
1248 if (parser.parseKeyword("mapping"))
1249 return failure();
1250
1251 DictionaryAttr mappingAttr;
1252 if (parser.parseAttribute(mappingAttr,
1253 parser.getBuilder().getType<NoneType>(), "mapping",
1254 result.attributes))
1255 return failure();
1256 return success();
1257}
1258
1259void FunctionLibraryOp::print(OpAsmPrinter &p) {
1260 p << ' ';
1261 p.printSymbolName(getName());
1263 (*this)->getAttrs(), {mlir::SymbolTable::getSymbolAttrName(), "mapping"});
1264 p << ' ';
1265 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
1266 /*printBlockTerminators=*/false);
1267 p << " mapping ";
1268 p.printAttributeWithoutType(getMappingAttr());
1269}
1270
1271//===----------------------------------------------------------------------===//
1272// FuncOp
1273//===----------------------------------------------------------------------===//
1274
1275FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1277 OpBuilder builder(location->getContext());
1278 OperationState state(location, getOperationName());
1279 FuncOp::build(builder, state, name, type, attrs);
1280 return cast<FuncOp>(Operation::create(state));
1281}
1282FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1284 SmallVector<NamedAttribute, 8> attrRef(attrs);
1285 return create(location, name, type, llvm::ArrayRef(attrRef));
1286}
1287FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
1289 ArrayRef<DictionaryAttr> argAttrs) {
1290 FuncOp func = create(location, name, type, attrs);
1291 func.setAllArgAttrs(argAttrs);
1292 return func;
1293}
1294
1295void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
1296 FunctionType type, ArrayRef<NamedAttribute> attrs,
1297 ArrayRef<DictionaryAttr> argAttrs) {
1298 state.addAttribute(FuncOp::getSymNameAttrName(state.name),
1299 builder.getStringAttr(name));
1300 state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
1301 TypeAttr::get(type));
1302 state.attributes.append(attrs.begin(), attrs.end());
1303 state.addRegion();
1304
1305 if (argAttrs.empty())
1306 return;
1307 assert(type.getNumInputs() == argAttrs.size());
1309 builder, state, argAttrs, /*resultAttrs=*/{},
1310 getArgAttrsAttrName(state.name), getResAttrsAttrName(state.name));
1311}
1312
1313ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
1314 auto buildFuncType =
1315 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
1317 std::string &) { return builder.getFunctionType(argTypes, results); };
1318
1320 parser, result, /*allowVariadic=*/false,
1321 getFunctionTypeAttrName(result.name), buildFuncType,
1322 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
1323}
1324
1325void FuncOp::print(OpAsmPrinter &p) {
1327 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
1328 getArgAttrsAttrName(), getResAttrsAttrName());
1329}
1330
1331//===----------------------------------------------------------------------===//
1332// GetExtentOp
1333//===----------------------------------------------------------------------===//
1334
1335std::optional<int64_t> GetExtentOp::getConstantDim() {
1336 if (auto constSizeOp = getDim().getDefiningOp<ConstSizeOp>())
1337 return constSizeOp.getValue().getLimitedValue();
1338 if (auto constantOp = getDim().getDefiningOp<arith::ConstantOp>())
1339 return llvm::cast<IntegerAttr>(constantOp.getValue()).getInt();
1340 return std::nullopt;
1341}
1342
1343OpFoldResult GetExtentOp::fold(FoldAdaptor adaptor) {
1344 auto elements = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1345 if (!elements)
1346 return nullptr;
1347 std::optional<int64_t> dim = getConstantDim();
1348 if (!dim.has_value())
1349 return nullptr;
1350 if (dim.value() >= elements.getNumElements())
1351 return nullptr;
1352 return elements.getValues<Attribute>()[(uint64_t)dim.value()];
1353}
1354
1355void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape,
1356 int64_t dim) {
1357 auto loc = result.location;
1358 auto dimAttr = builder.getIndexAttr(dim);
1359 if (llvm::isa<ShapeType>(shape.getType())) {
1360 Value dim = ConstSizeOp::create(builder, loc, dimAttr);
1361 build(builder, result, builder.getType<SizeType>(), shape, dim);
1362 } else {
1363 Value dim = arith::ConstantOp::create(builder, loc, builder.getIndexType(),
1364 dimAttr);
1365 build(builder, result, builder.getIndexType(), shape, dim);
1366 }
1367}
1368
1369LogicalResult mlir::shape::GetExtentOp::inferReturnTypes(
1370 MLIRContext *context, std::optional<Location> location,
1371 GetExtentOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1372 inferredReturnTypes.assign({IndexType::get(context)});
1373 return success();
1374}
1375
1376bool mlir::shape::GetExtentOp::isCompatibleReturnTypes(TypeRange l,
1377 TypeRange r) {
1378 // SizeType is compatible with IndexType.
1380}
1381
1382LogicalResult GetExtentOp::verify() { return verifySizeOrIndexOp(*this); }
1383
1384//===----------------------------------------------------------------------===//
1385// IsBroadcastableOp
1386//===----------------------------------------------------------------------===//
1387
1388void IsBroadcastableOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1389 MLIRContext *context) {
1390 patterns.add<RemoveDuplicateOperandsPattern<IsBroadcastableOp>>(context);
1391}
1392
1393OpFoldResult IsBroadcastableOp::fold(FoldAdaptor adaptor) {
1394 // Can always broadcast fewer than two shapes.
1395 if (adaptor.getShapes().size() < 2) {
1396 return BoolAttr::get(getContext(), true);
1397 }
1398
1399 return nullptr;
1400}
1401
1402//===----------------------------------------------------------------------===//
1403// MeetOp
1404//===----------------------------------------------------------------------===//
1405
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    MeetOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
  // Infers the single result type as the most refined type compatible with
  // all operands. Operands must be all size-like (!shape.size / index) or all
  // shape-like (!shape.shape / extent tensor); mixing the groups is an error.
  if (adaptor.getOperands().empty())
    return failure();

  // "Shape-like" here means !shape.shape or a 1-D index tensor.
  auto isShapeType = [](Type arg) {
    if (llvm::isa<ShapeType>(arg))
      return true;
    return isExtentTensorType(arg);
  };

  ValueRange::type_range types = adaptor.getOperands().getTypes();
  Type acc = types.front();
  // Fold the meet over the operand types pairwise, accumulating into `acc`.
  for (auto t : drop_begin(types)) {
    Type l = acc, r = t;
    // Normalize so that a !shape.shape/!shape.size type, if any, is on `l`.
    if (!llvm::isa<ShapeType, SizeType>(l))
      std::swap(l, r);

    // Handle sizes, propagate error type if present.
    if (llvm::isa<SizeType>(l)) {
      if (llvm::isa<SizeType, IndexType>(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<IndexType>(l)) {
      if (llvm::isa<IndexType>(r))
        acc = r;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (llvm::isa<ShapeType>(l)) {
      // Handle shapes, propagate error type if present.
      if (isShapeType(r))
        acc = l;
      else
        return emitOptionalError(location, "requires all sizes or shapes");
    } else if (isExtentTensorType(l)) {
      // Both sides are extent tensors: static extents must agree, and a
      // static extent refines a dynamic one.
      auto rank1 = llvm::cast<RankedTensorType>(l).getShape()[0];
      auto rank2 = llvm::cast<RankedTensorType>(r).getShape()[0];
      if (ShapedType::isDynamic(rank1))
        acc = l;
      else if (ShapedType::isDynamic(rank2))
        acc = r;
      else if (rank1 != rank2)
        return emitOptionalError(location, "unequal shape cardinality");
      else
        acc = l;
    }
  }
  inferredReturnTypes.assign({acc});
  return success();
}
1458
1459bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1460 if (l.size() != 1 || r.size() != 1)
1461 return false;
1462 if (l == r)
1463 return true;
1464
1465 Type lhs = l.front();
1466 Type rhs = r.front();
1467
1468 if (!llvm::isa<ShapeType, SizeType>(lhs))
1469 std::swap(lhs, rhs);
1470
1471 if (llvm::isa<SizeType>(lhs))
1472 return llvm::isa<SizeType, IndexType>(rhs);
1473 if (llvm::isa<ShapeType>(lhs))
1474 return llvm::isa<ShapeType, TensorType>(rhs);
1475
1476 if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1477 return true;
1478 return false;
1479}
1480
1481//===----------------------------------------------------------------------===//
1482// RankOp
1483//===----------------------------------------------------------------------===//
1484
1485OpFoldResult shape::RankOp::fold(FoldAdaptor adaptor) {
1486 auto shape = llvm::dyn_cast_if_present<DenseIntElementsAttr>(adaptor.getShape());
1487 if (!shape)
1488 return {};
1489 int64_t rank = shape.getNumElements();
1490 Builder builder(getContext());
1491 return builder.getIndexAttr(rank);
1492}
1493
1494/// Evaluate the `rank` operation for shapes of ranked tensors at compile time.
1495/// Constant folding fails in cases where only the rank is constant, not the
1496/// shape itself.
1497/// This canonicalization matches `shape.rank(shape.shape_of(%ranked_tensor))`.
1498///
1499/// Example:
1500///
1501/// %shape = shape.shape_of %ranked_tensor : tensor<1x2x?xf32>
1502/// %rank = shape.rank %shape
1503///
1504/// becomes
1505///
1506/// %rank = shape.const_size 3
1507
1508namespace {
1509struct RankShapeOfCanonicalizationPattern
1510 : public OpRewritePattern<shape::RankOp> {
1511 using OpRewritePattern<shape::RankOp>::OpRewritePattern;
1512
1513 LogicalResult matchAndRewrite(shape::RankOp op,
1514 PatternRewriter &rewriter) const override {
1515 auto shapeOfOp = op.getShape().getDefiningOp<ShapeOfOp>();
1516 if (!shapeOfOp)
1517 return failure();
1518 auto rankedTensorType =
1519 llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1520 if (!rankedTensorType)
1521 return failure();
1522 int64_t rank = rankedTensorType.getRank();
1523 if (llvm::isa<IndexType>(op.getType())) {
1524 rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(op.getOperation(),
1525 rank);
1526 } else if (llvm::isa<shape::SizeType>(op.getType())) {
1527 rewriter.replaceOpWithNewOp<shape::ConstSizeOp>(op.getOperation(), rank);
1528 } else {
1529 return failure();
1530 }
1531 return success();
1532 }
1533};
1534} // namespace
1535
1536void shape::RankOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1537 MLIRContext *context) {
1538 patterns.add<RankShapeOfCanonicalizationPattern>(context);
1539}
1540
1541LogicalResult mlir::shape::RankOp::inferReturnTypes(
1542 MLIRContext *context, std::optional<Location> location,
1543 RankOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1544 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1545 inferredReturnTypes.assign({SizeType::get(context)});
1546 else
1547 inferredReturnTypes.assign({IndexType::get(context)});
1548 return success();
1549}
1550
1551bool mlir::shape::RankOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1552 // SizeType is compatible with IndexType.
1554}
1555
1556LogicalResult shape::RankOp::verify() { return verifySizeOrIndexOp(*this); }
1557
1558//===----------------------------------------------------------------------===//
1559// NumElementsOp
1560//===----------------------------------------------------------------------===//
1561
1562OpFoldResult NumElementsOp::fold(FoldAdaptor adaptor) {
1563
1564 // Fold only when argument constant.
1565 Attribute shape = adaptor.getShape();
1566 if (!shape)
1567 return {};
1568
1569 APInt product(64, 1);
1570 for (auto value : llvm::cast<DenseIntElementsAttr>(shape))
1571 product *= value;
1572 Builder builder(getContext());
1573 return builder.getIndexAttr(product.getLimitedValue());
1574}
1575
1576LogicalResult mlir::shape::NumElementsOp::inferReturnTypes(
1577 MLIRContext *context, std::optional<Location> location,
1578 NumElementsOp::Adaptor adaptor,
1579 SmallVectorImpl<Type> &inferredReturnTypes) {
1580 if (llvm::isa<ShapeType>(adaptor.getShape().getType()))
1581 inferredReturnTypes.assign({SizeType::get(context)});
1582 else
1583 inferredReturnTypes.assign({IndexType::get(context)});
1584 return success();
1585}
1586
1587bool mlir::shape::NumElementsOp::isCompatibleReturnTypes(TypeRange l,
1588 TypeRange r) {
1589 // SizeType is compatible with IndexType.
1591}
1592
LogicalResult shape::NumElementsOp::verify() {
  // Delegate to the shared verifier for size/index-producing ops.
  return verifySizeOrIndexOp(*this);
}
1596
1597//===----------------------------------------------------------------------===//
1598// MaxOp
1599//===----------------------------------------------------------------------===//
1600
1601OpFoldResult MaxOp::fold(FoldAdaptor adaptor) {
1602 // If operands are equal, just propagate one.
1603 if (getLhs() == getRhs())
1604 return getLhs();
1605 return nullptr;
1606}
1607
1608LogicalResult mlir::shape::MaxOp::inferReturnTypes(
1609 MLIRContext *context, std::optional<Location> location,
1610 MaxOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1611 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1612 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1613 else
1614 inferredReturnTypes.assign({SizeType::get(context)});
1615 return success();
1616}
1617
1618bool mlir::shape::MaxOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1619 if (l.size() != 1 || r.size() != 1)
1620 return false;
1621 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1622 return true;
1623 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1624 return true;
1625 return false;
1626}
1627
1628//===----------------------------------------------------------------------===//
1629// MinOp
1630//===----------------------------------------------------------------------===//
1631
1632OpFoldResult MinOp::fold(FoldAdaptor adaptor) {
1633 // If operands are equal, just propagate one.
1634 if (getLhs() == getRhs())
1635 return getLhs();
1636 return nullptr;
1637}
1638
1639LogicalResult mlir::shape::MinOp::inferReturnTypes(
1640 MLIRContext *context, std::optional<Location> location,
1641 MinOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1642 if (adaptor.getLhs().getType() == adaptor.getRhs().getType())
1643 inferredReturnTypes.assign({adaptor.getLhs().getType()});
1644 else
1645 inferredReturnTypes.assign({SizeType::get(context)});
1646 return success();
1647}
1648
1649bool mlir::shape::MinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1650 if (l.size() != 1 || r.size() != 1)
1651 return false;
1652 if (llvm::isa<ShapeType>(l.front()) && llvm::isa<ShapeType>(r.front()))
1653 return true;
1654 if (llvm::isa<SizeType>(l.front()) && llvm::isa<SizeType>(r.front()))
1655 return true;
1656 return false;
1657}
1658
1659//===----------------------------------------------------------------------===//
1660// MulOp
1661//===----------------------------------------------------------------------===//
1662
1663OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1664 auto lhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getLhs());
1665 if (!lhs)
1666 return nullptr;
1667 auto rhs = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getRhs());
1668 if (!rhs)
1669 return nullptr;
1670 APInt folded = lhs.getValue() * rhs.getValue();
1671 Type indexTy = IndexType::get(getContext());
1672 return IntegerAttr::get(indexTy, folded);
1673}
1674
1675LogicalResult mlir::shape::MulOp::inferReturnTypes(
1676 MLIRContext *context, std::optional<Location> location,
1677 MulOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1678 if (llvm::isa<SizeType>(adaptor.getLhs().getType()) ||
1679 llvm::isa<SizeType>(adaptor.getRhs().getType()))
1680 inferredReturnTypes.assign({SizeType::get(context)});
1681 else
1682 inferredReturnTypes.assign({IndexType::get(context)});
1683 return success();
1684}
1685
1686bool mlir::shape::MulOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1687 // SizeType is compatible with IndexType.
1689}
1690
1691LogicalResult shape::MulOp::verify() { return verifySizeOrIndexOp(*this); }
1692
1693//===----------------------------------------------------------------------===//
1694// ShapeOfOp
1695//===----------------------------------------------------------------------===//
1696
1697namespace {
1698/// Replace shape_of(x) where x has a constant shape with a const_shape op.
1699struct ShapeOfOpToConstShapeOp : public OpRewritePattern<shape::ShapeOfOp> {
1700 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1701
1702 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1703 PatternRewriter &rewriter) const override {
1704 auto type = llvm::dyn_cast<ShapedType>(op.getArg().getType());
1705 if (!type || !type.hasStaticShape())
1706 return failure();
1707 Location loc = op.getLoc();
1708 Value constShape =
1709 ConstShapeOp::create(rewriter, loc,
1710 rewriter.getIndexTensorAttr(type.getShape()))
1711 .getResult();
1712 if (constShape.getType() != op.getResult().getType())
1713 constShape = tensor::CastOp::create(rewriter, loc,
1714 op.getResult().getType(), constShape);
1715 rewriter.replaceOp(op, constShape);
1716 return success();
1717 }
1718};
1719
1720// Canonicalize
1721//
1722// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1723// %1 = shape.shape_of %0 : tensor<*xf32> -> tensor<?xindex>
1724//
1725// to
1726//
1727// %0 = tensor.reshape %input(%shape) : (tensor<*xf32>, tensor<?xindex>) -> tensor<*xf32>
1728// %1 = %shape
1729//
1730struct ShapeOfFromReshape : public OpRewritePattern<shape::ShapeOfOp> {
1731 using OpRewritePattern<shape::ShapeOfOp>::OpRewritePattern;
1732
1733 LogicalResult matchAndRewrite(shape::ShapeOfOp op,
1734 PatternRewriter &rewriter) const override {
1735 auto tensorReshapeOp = op.getArg().getDefiningOp<tensor::ReshapeOp>();
1736 if (!tensorReshapeOp)
1737 return rewriter.notifyMatchFailure(op, "producer is not tensor.reshape");
1738 if (!isa<TensorType>(op.getType()))
1739 return rewriter.notifyMatchFailure(op, "result is not a tensor");
1740
1741 // Operand 'shape' of 'tensor.reshape' may now be used as the result of
1742 // 'shape.shape_of'. While its type is guaranteed to be compatible in well-
1743 // formed IR, it may not be identical (dynamically vs statically shaped),
1744 // in which case it needs to be cast first using 'tensor.cast'.
1745 // Additionally, it may not have identical element type (i32 vs index)
1746 // while it has identical shaped type (dynamic vs static), in which case it
1747 // needs to be cast first using 'arith.index_cast'. Note: 'shape.shape_of'
1748 // op result must be shape or extent tensor.
1749 Value shape = tensorReshapeOp.getShape();
1750
1751 auto opTensorTy = cast<RankedTensorType>(op.getType());
1752 auto shapeTensorTy = cast<RankedTensorType>(shape.getType());
1753
1754 if (opTensorTy != shapeTensorTy) {
1755 if (opTensorTy.getElementType() == shapeTensorTy.getElementType())
1756 shape =
1757 tensor::CastOp::create(rewriter, op.getLoc(), opTensorTy, shape);
1758 else if (!isExtentTensorType(shapeTensorTy))
1759 shape = arith::IndexCastOp::create(rewriter, op.getLoc(), opTensorTy,
1760 shape);
1761 }
1762
1763 rewriter.replaceOp(op, shape);
1764 return success();
1765 }
1766};
1767
1768// Canonicalize
1769// ```
1770// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
1771// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
1772// ```
1773// to
1774// ```
1775// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
1776// ```
1777struct ShapeOfCastExtentTensor : public OpRewritePattern<tensor::CastOp> {
1778 using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
1779
1780 LogicalResult matchAndRewrite(tensor::CastOp op,
1781 PatternRewriter &rewriter) const override {
1782 auto ty = llvm::dyn_cast<RankedTensorType>(op.getType());
1783 if (!ty || ty.getRank() != 1)
1784 return failure();
1785
1786 auto shapeOfOp = op.getSource().getDefiningOp<ShapeOfOp>();
1787 if (!shapeOfOp)
1788 return failure();
1789
1790 // Argument type must be ranked and must not conflict.
1791 auto argTy = llvm::dyn_cast<RankedTensorType>(shapeOfOp.getArg().getType());
1792 if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
1793 return failure();
1794
1795 rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.getArg());
1796 return success();
1797 }
1798};
1799} // namespace
1800
1801void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
1802 MLIRContext *context) {
1803 patterns.add<ShapeOfCastExtentTensor, ShapeOfFromReshape,
1804 ExtractFromShapeOfExtentTensor, ShapeOfOpToConstShapeOp>(
1805 context);
1806}
1807
1808LogicalResult mlir::shape::ShapeOfOp::inferReturnTypes(
1809 MLIRContext *context, std::optional<Location> location,
1810 ShapeOfOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
1811 if (llvm::isa<ValueShapeType>(adaptor.getArg().getType()))
1812 inferredReturnTypes.assign({ShapeType::get(context)});
1813 else {
1814 auto shapedTy = llvm::cast<ShapedType>(adaptor.getArg().getType());
1815 int64_t rank =
1816 shapedTy.hasRank() ? shapedTy.getRank() : ShapedType::kDynamic;
1817 Type indexTy = IndexType::get(context);
1818 Type extentTensorTy = RankedTensorType::get({rank}, indexTy);
1819 inferredReturnTypes.assign({extentTensorTy});
1820 }
1821 return success();
1822}
1823
1824bool mlir::shape::ShapeOfOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
1825 if (l.size() != 1 || r.size() != 1)
1826 return false;
1827 if (l == r)
1828 return true;
1829
1830 Type lhs = l.front();
1831 Type rhs = r.front();
1832
1833 if (!llvm::isa<ShapeType, ShapedType>(lhs) ||
1834 !llvm::isa<ShapeType, ShapedType>(rhs))
1835 return false;
1836
1837 if (llvm::isa<ShapeType>(lhs) || llvm::isa<ShapeType>(rhs))
1838 // Shape type is compatible with all other valid return types.
1839 return true;
1840
1841 if (succeeded(verifyCompatibleShapes({lhs, rhs})))
1842 return true;
1843 return false;
1844}
1845
// Delegates to the shared verifier for ops whose result must be a
// `!shape.shape` or an extent tensor (see verifyShapeOrExtentTensorOp above).
LogicalResult shape::ShapeOfOp::verify() {
  return verifyShapeOrExtentTensorOp(*this);
}
1849
1850//===----------------------------------------------------------------------===//
1851// SizeToIndexOp
1852//===----------------------------------------------------------------------===//
1853
1854OpFoldResult SizeToIndexOp::fold(FoldAdaptor adaptor) {
1855 // Constant values of both types, `shape.size` and `index`, are represented as
1856 // `IntegerAttr`s which makes constant folding simple.
1857 if (Attribute arg = adaptor.getArg())
1858 return arg;
1859 return OpFoldResult();
1860}
1861
// Registers the pattern that cancels an index_cast/size_to_index round trip
// (declared in ShapeCanonicalization.inc).
void SizeToIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                                MLIRContext *context) {
  patterns.add<IndexToSizeToIndexCanonicalization>(context);
}
1866
1867bool SizeToIndexOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1868 if (inputs.size() != 1 || outputs.size() != 1)
1869 return false;
1870 return llvm::isa<IndexType, SizeType>(inputs[0]) &&
1871 llvm::isa<IndexType>(outputs[0]);
1872}
1873
1874//===----------------------------------------------------------------------===//
1875// YieldOp
1876//===----------------------------------------------------------------------===//
1877
1878LogicalResult shape::YieldOp::verify() {
1879 auto *parentOp = (*this)->getParentOp();
1880 auto results = parentOp->getResults();
1881 auto operands = getOperands();
1882
1883 if (parentOp->getNumResults() != getNumOperands())
1884 return emitOpError() << "number of operands does not match number of "
1885 "results of its parent";
1886 for (auto e : llvm::zip(results, operands))
1887 if (std::get<0>(e).getType() != std::get<1>(e).getType())
1888 return emitOpError() << "types mismatch between yield op and its parent";
1889
1890 return success();
1891}
1892
1893//===----------------------------------------------------------------------===//
1894// SplitAtOp
1895//===----------------------------------------------------------------------===//
1896
1897LogicalResult SplitAtOp::fold(FoldAdaptor adaptor,
1899 if (!adaptor.getOperand() || !adaptor.getIndex())
1900 return failure();
1901 auto shapeVec = llvm::to_vector<6>(
1902 llvm::cast<DenseIntElementsAttr>(adaptor.getOperand()).getValues<int64_t>());
1903 auto shape = llvm::ArrayRef(shapeVec);
1904 auto splitPoint = llvm::cast<IntegerAttr>(adaptor.getIndex()).getInt();
1905 // Verify that the split point is in the correct range.
1906 // TODO: Constant fold to an "error".
1907 int64_t rank = shape.size();
1908 if (-rank > splitPoint || splitPoint > rank)
1909 return failure();
1910 if (splitPoint < 0)
1911 splitPoint += shape.size();
1912 Builder builder(adaptor.getOperand().getContext());
1913 results.push_back(builder.getIndexTensorAttr(shape.take_front(splitPoint)));
1914 results.push_back(builder.getIndexTensorAttr(shape.drop_front(splitPoint)));
1915 return success();
1916}
1917
1918//===----------------------------------------------------------------------===//
1919// ToExtentTensorOp
1920//===----------------------------------------------------------------------===//
1921
1922OpFoldResult ToExtentTensorOp::fold(FoldAdaptor adaptor) {
1923 if (!adaptor.getInput())
1924 return OpFoldResult();
1925 Builder builder(getContext());
1926 auto shape = llvm::to_vector<6>(
1927 llvm::cast<DenseIntElementsAttr>(adaptor.getInput()).getValues<int64_t>());
1928 auto type = RankedTensorType::get({static_cast<int64_t>(shape.size())},
1929 builder.getIndexType());
1930 return DenseIntElementsAttr::get(type, shape);
1931}
1932
1933bool ToExtentTensorOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1934 if (inputs.size() != 1 || outputs.size() != 1)
1935 return false;
1936 if (auto inputTensor = llvm::dyn_cast<RankedTensorType>(inputs[0])) {
1937 if (!llvm::isa<IndexType>(inputTensor.getElementType()) ||
1938 inputTensor.getRank() != 1)
1939 return false;
1940 } else if (!llvm::isa<ShapeType>(inputs[0])) {
1941 return false;
1942 }
1943
1944 TensorType outputTensor = llvm::dyn_cast<TensorType>(outputs[0]);
1945 return outputTensor && llvm::isa<IndexType>(outputTensor.getElementType());
1946}
1947
1948//===----------------------------------------------------------------------===//
1949// ReduceOp
1950//===----------------------------------------------------------------------===//
1951
// Builds a shape.reduce over `shape` with initial aggregate values `initVals`,
// creating the body block with its expected arguments: the iteration index,
// the extent, and one aggregate per initial value. The op's results mirror the
// initial values' types.
void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
                     ValueRange initVals) {
  // Restore the builder's insertion point after createBlock moves it.
  OpBuilder::InsertionGuard g(builder);
  result.addOperands(shape);
  result.addOperands(initVals);

  // The body block's first argument is the iteration index (`index` type).
  Region *bodyRegion = result.addRegion();
  Block *bodyBlock = builder.createBlock(
      bodyRegion, /*insertPt=*/{}, builder.getIndexType(), result.location);

  // Second argument: the extent. Its type depends on whether we reduce over an
  // extent tensor (element type, i.e. `index`) or a `!shape.shape` (`size`).
  Type elementType;
  if (auto tensorType = llvm::dyn_cast<TensorType>(shape.getType()))
    elementType = tensorType.getElementType();
  else
    elementType = SizeType::get(builder.getContext());
  bodyBlock->addArgument(elementType, shape.getLoc());

  // Remaining arguments and op results: one per initial aggregate value.
  for (Value initVal : initVals) {
    bodyBlock->addArgument(initVal.getType(), initVal.getLoc());
    result.addTypes(initVal.getType());
  }
}
1974
1975LogicalResult ReduceOp::verify() {
1976 // Verify block arg types.
1977 Block &block = getRegion().front();
1978
1979 // The block takes index, extent, and aggregated values as arguments.
1980 auto blockArgsCount = getInitVals().size() + 2;
1981 if (block.getNumArguments() != blockArgsCount)
1982 return emitOpError() << "ReduceOp body is expected to have "
1983 << blockArgsCount << " arguments";
1984
1985 // The first block argument is the index and must always be of type `index`.
1986 if (!llvm::isa<IndexType>(block.getArgument(0).getType()))
1987 return emitOpError(
1988 "argument 0 of ReduceOp body is expected to be of IndexType");
1989
1990 // The second block argument is the extent and must be of type `size` or
1991 // `index`, depending on whether the reduce operation is applied to a shape or
1992 // to an extent tensor.
1993 Type extentTy = block.getArgument(1).getType();
1994 if (llvm::isa<ShapeType>(getShape().getType())) {
1995 if (!llvm::isa<SizeType>(extentTy))
1996 return emitOpError("argument 1 of ReduceOp body is expected to be of "
1997 "SizeType if the ReduceOp operates on a ShapeType");
1998 } else {
1999 if (!llvm::isa<IndexType>(extentTy))
2000 return emitOpError(
2001 "argument 1 of ReduceOp body is expected to be of IndexType if the "
2002 "ReduceOp operates on an extent tensor");
2003 }
2004
2005 for (const auto &type : llvm::enumerate(getInitVals()))
2006 if (block.getArgument(type.index() + 2).getType() != type.value().getType())
2007 return emitOpError() << "type mismatch between argument "
2008 << type.index() + 2
2009 << " of ReduceOp body and initial value "
2010 << type.index();
2011 return success();
2012}
2013
2014ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
2015 // Parse operands.
2017 Type shapeOrExtentTensorType;
2018 if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
2020 parser.parseColonType(shapeOrExtentTensorType) ||
2021 parser.parseOptionalArrowTypeList(result.types))
2022 return failure();
2023
2024 // Resolve operands.
2025 auto initVals = llvm::ArrayRef(operands).drop_front();
2026 if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType,
2027 result.operands) ||
2028 parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
2029 result.operands))
2030 return failure();
2031
2032 // Parse the body.
2033 Region *body = result.addRegion();
2034 if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
2035 return failure();
2036
2037 // Parse attributes.
2038 if (parser.parseOptionalAttrDict(result.attributes))
2039 return failure();
2040
2041 return success();
2042}
2043
// Custom printer for shape.reduce; mirrors ReduceOp::parse:
//   shape.reduce(%shape, %init...) : <type> [-> result-types] { body } [attrs]
void ReduceOp::print(OpAsmPrinter &p) {
  p << '(' << getShape() << ", " << getInitVals()
    << ") : " << getShape().getType();
  p.printOptionalArrowTypeList(getResultTypes());
  p << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}
2052
2053#define GET_OP_CLASSES
2054#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
2055
2056#define GET_TYPEDEF_CLASSES
2057#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool isErrorPropagationPossible(TypeRange operandTypes)
Definition Shape.cpp:66
static bool hasAtMostSingleNonScalar(ArrayRef< Attribute > attributes)
Definition Shape.cpp:978
static LogicalResult verifyShapeOrExtentTensorOp(Operation *op)
Definition Shape.cpp:83
static bool eachHasOnlyOneOfTypes(TypeRange typeRange)
Definition Shape.cpp:96
static LogicalResult verifySizeOrIndexOp(Operation *op)
Definition Shape.cpp:71
static int64_t product(ArrayRef< int64_t > vals)
lhs
static bool isLegalToInline(InlinerInterface &interface, Region *src, Region *insertRegion, bool shouldCloneInlinedRegion, IRMapping &valueMapping)
Utility to check that all of the operations within 'src' can be inlined.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
ParseResult parseSymbolName(StringAttr &result)
Parse an -identifier and store it (without the '@' symbol) in a string attribute.
@ Paren
Parens surrounding zero or more operands.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalAttrDictWithKeyword(NamedAttrList &result)=0
Parse a named dictionary into 'result' if the attributes keyword is present.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
virtual void printType(Type type)
virtual void printSymbolName(StringRef symbolRef)
Print the given string as a symbol reference, i.e.
void printOptionalArrowTypeList(TypeRange &&types)
Print an optional arrow followed by a type list.
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Return the context this attribute belongs to.
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Block.cpp:158
static BoolAttr get(MLIRContext *context, bool value)
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
FunctionType getFunctionType(TypeRange inputs, TypeRange results)
Definition Builders.cpp:76
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition Builders.h:91
StringAttr getStringAttr(const Twine &bytes)
Definition Builders.cpp:262
DenseIntElementsAttr getIndexTensorAttr(ArrayRef< int64_t > values)
Definition Builders.cpp:193
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
NamedAttribute getNamedAttr(StringRef name, Attribute val)
Definition Builders.cpp:94
An attribute that represents a reference to a dense integer vector or tensor object.
static DenseIntElementsAttr get(const ShapedType &type, Arg &&arg)
Get an instance of a DenseIntElementsAttr with the given arguments.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
NamedAttrList is array of NamedAttributes that tracks whether it is sorted and does some basic work t...
void append(StringRef name, Attribute attr)
Add an attribute with the specified name.
NamedAttribute represents a combination of a name and an Attribute value.
Definition Attributes.h:164
StringAttr getName() const
Return the name of the attribute.
Attribute getValue() const
Return the value of the attribute.
Definition Attributes.h:179
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDictWithKeyword(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary prefixed with 'attribute...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block::iterator getInsertionPoint() const
Returns the current insertion point of the builder.
Definition Builders.h:445
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void setInsertionPointToEnd(Block *block)
Sets the insertion point to the end of the specified block.
Definition Builders.h:436
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
Definition Builders.h:442
This class represents a single result from folding an operation.
A trait used to provide symbol table functionalities to a region operation.
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
iterator_range< dialect_attr_iterator > dialect_attr_range
Definition Operation.h:634
static Operation * create(Location location, OperationName name, TypeRange resultTypes, ValueRange operands, NamedAttrList &&attributes, OpaqueProperties properties, BlockRange successors, unsigned numRegions)
Create a new Operation with the specific fields.
Definition Operation.cpp:67
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OperationName getName()
The name of an operation is the key identifier for it.
Definition Operation.h:119
operand_type_range getOperandTypes()
Definition Operation.h:397
result_type_range getResultTypes()
Definition Operation.h:428
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
unsigned getNumResults()
Return the number of results held by this operation.
Definition Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block * splitBlock(Block *block, Block::iterator before)
Split the operations starting at "before" (inclusive) out of the given block into a new block,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StringRef getSymbolAttrName()
Return the name of the attribute used for symbol names.
Definition SymbolTable.h:76
static Operation * lookupSymbolIn(Operation *op, StringAttr symbol)
Returns the operation registered with the given symbol name with the regions of 'symbolTableOp'.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
Type getElementType() const
Returns the element type of this tensor type.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
ValueTypeRange< ValueRange > type_range
Definition ValueRange.h:418
Type front()
Return first type in the range.
Definition TypeRange.h:152
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A named class for passing around the variadic flag.
bool staticallyKnownBroadcastable(ArrayRef< SmallVector< int64_t, 6 > > shapes)
Returns true if a broadcast between n shapes is guaranteed to be successful and not result in an erro...
Definition Traits.cpp:24
bool getBroadcastedShape(ArrayRef< int64_t > shape1, ArrayRef< int64_t > shape2, SmallVectorImpl< int64_t > &resultShape)
Returns true and sets resultShape to the broadcasted shape from the two given shapes if they are broa...
Definition Traits.cpp:59
void addArgAndResultAttrs(Builder &builder, OperationState &result, ArrayRef< DictionaryAttr > argAttrs, ArrayRef< DictionaryAttr > resultAttrs, StringAttr argAttrsName, StringAttr resAttrsName)
Adds argument and result attributes, provided as argAttrs and resultAttrs arguments,...
void printFunctionOp(OpAsmPrinter &p, FunctionOpInterface op, bool isVariadic, StringRef typeAttrName, StringAttr argAttrsName, StringAttr resAttrsName)
Printer implementation for function-like operations.
ParseResult parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, StringAttr typeAttrName, FuncTypeBuilder funcTypeBuilder, StringAttr argAttrsName, StringAttr resAttrsName)
Parser implementation for function-like operations.
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
bool isExtentTensorType(Type)
Definition Shape.cpp:44
LogicalResult getShapeVec(Value input, SmallVectorImpl< int64_t > &shapeValues)
Definition Shape.cpp:49
RankedTensorType getExtentTensorType(MLIRContext *ctx, int64_t rank=ShapedType::kDynamic)
Alias type for extent tensors.
Definition Shape.cpp:40
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
LogicalResult verifyCompatibleShapes(TypeRange types1, TypeRange types2)
Returns success if the given two arrays have the same number of elements and each pair wise entries h...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, Type resultType, CalculationT &&calculate)
LogicalResult emitOptionalError(std::optional< Location > loc, Args &&...args)
Overloads of the above emission functions that take an optionally null location.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:120
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::SetVector< T, Vector, Set, N > SetVector
Definition LLVM.h:123
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addAttribute(StringRef name, Attribute attr)
Add an attribute with the specified name.
Region * addRegion()
Create a region that should be attached to the operation.