//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVectorExtras.h"

using namespace mlir;
using namespace mlir::memref;
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
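
// For example (illustrative; `%0` is a placeholder):
//
//   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
//   memref.dealloc %1 : memref<?x?xf32>
//
// folds to:
//
//   memref.dealloc %0 : memref<8x16xf32>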

/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}
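
// For example (illustrative): memref<4x?xf32> maps to tensor<4x?xf32>,
// memref<*xf32> maps to tensor<*xf32>, and any non-memref type maps to none.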

OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc,
                                  Value value, int64_t dim) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  return builder.getIndexAttr(memrefType.getDimSize(dim));
}

SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < memrefType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}
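
// For example (illustrative): for a value %m of type memref<4x?xf32>,
// getMixedSizes returns {IndexAttr(4), %dim} where %dim is the Value produced
// by `memref.dim %m, %c1`.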

//===----------------------------------------------------------------------===//
// Utility functions for propagating static information
//===----------------------------------------------------------------------===//

/// Helper function that sets values[i] to constValues[i] if the latter is a
/// static value, as indicated by ShapedType::kDynamic.
///
/// If constValues[i] is dynamic, tries to extract a constant value from
/// value[i] to allow for additional folding opportunities. Also converts all
/// existing attributes to index attributes. (They may be i64 attributes.)
static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
                                ArrayRef<int64_t> constValues) {
  assert(constValues.size() == values.size() &&
         "incorrect number of const values");
  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
    Builder builder(values[i].getContext());
    if (ShapedType::isStatic(cstVal)) {
      // Constant value is known, use it directly.
      values[i] = builder.getIndexAttr(cstVal);
      continue;
    }
    if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
      // Try to extract a constant or convert an existing attribute to an
      // index attribute.
      values[i] = builder.getIndexAttr(*cst);
    }
  }
}
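
// For example (illustrative): with values = {%c4, 10 : i64} (where %c4 is
// produced by an arith.constant) and constValues = {kDynamic, 10}, the result
// is {4 : index, 10 : index}: the 4 is extracted from %c4, and the i64
// attribute is converted to an index attribute.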

/// Helper function to retrieve a lossless memory-space cast, and the
/// corresponding new result memref type.
static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
getMemorySpaceCastInfo(PtrLikeTypeInterface resultTy, Value src) {
  MemorySpaceCastOpInterface castOp =
      MemorySpaceCastOpInterface::getIfPromotableCast(src);

  // Bail if the cast is not lossless.
  if (!castOp)
    return {};

  // Transform the source and target type of `castOp` to have the same metadata
  // as `resultTy`. Bail if not possible.
  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
      castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
  if (failed(srcTy))
    return {};

  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
      castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
  if (failed(tgtTy))
    return {};

  // Check if this is a valid memory-space cast.
  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
    return {};

  return std::make_tuple(castOp, *tgtTy, *srcTy);
}

/// Implementation of the `bubbleDownCasts` method for memref operations that
/// return a single memref result.
template <typename ConcreteOpTy>
static FailureOr<std::optional<SmallVector<Value>>>
bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
                                 OpOperand &src) {
  auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
  // Bail if we cannot cast.
  if (!castOp)
    return failure();

  // Create the new operands.
  SmallVector<Value> operands;
  llvm::append_range(operands, op->getOperands());
  operands[src.getOperandNumber()] = castOp.getSourcePtr();

  // Create the new op and results.
  auto newOp = ConcreteOpTy::create(
      builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
      llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));

  // Insert a memory-space cast to the original memory space of the op.
  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
      builder, tgtTy,
      cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
  return std::optional<SmallVector<Value>>(
      SmallVector<Value>({result.getTargetPtr()}));
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (failed(verifyDynamicDimensionCount(op, memRefType, op.getDynamicSizes())))
    return failure();

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();

  return success();
}

LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }

LogicalResult AllocaOp::verify() {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(*this);
}

namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old
        // memref.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
                                        dynamicSizes, alloc.getSymbolOperands(),
                                        alloc.getAlignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
    return success();
  }
};
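
// A minimal sketch of what SimplifyAllocConst does (illustrative):
//
//   %c8 = arith.constant 8 : index
//   %0 = memref.alloc(%c8, %n) : memref<?x?xf32>
//
// rewrites to:
//
//   %1 = memref.alloc(%n) : memref<8x?xf32>
//   %0 = memref.cast %1 : memref<8x?xf32> to memref<?x?xf32>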

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
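
// A minimal sketch of what SimplifyDeadAlloc does (illustrative): an alloc
// whose only uses are stores into it and a dealloc is removed entirely:
//
//   %0 = memref.alloc() : memref<4xf32>
//   memref.store %v, %0[%i] : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>
//
// All three ops are erased.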
} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// ReallocOp
//===----------------------------------------------------------------------===//

LogicalResult ReallocOp::verify() {
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  // The source memref should have identity layout (or none).
  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")
           << sourceType;

  // The result memref should have identity layout (or none).
  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")
           << resultType;

  // The source memref and the result memref should be in the same memory
  // space.
  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // The source memref and the result memref should have the same element type.
  if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
                                     "result")))
    return failure();

  // Verify that we have the dynamic dimension operand when it is needed.
  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
           << resultType;
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")
           << resultType;

  return success();
}
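
// For example (illustrative): a well-formed realloc to a dynamically sized
// result takes the dynamic output size as an operand:
//
//   %new = memref.realloc %old(%n) : memref<8xf32> to memref<?xf32>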

void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

void AllocaScopeOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getBodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs());
}

ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
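
// For example (illustrative), the custom syntax round-tripped by the printer
// and parser above (names are placeholders):
//
//   %r = memref.alloca_scope -> (f32) {
//     %buf = memref.alloca() : memref<4xf32>
//     %v = memref.load %buf[%c0] : memref<4xf32>
//     memref.alloca_scope.return %v : f32
//   }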

void AllocaScopeOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor::parent());
    return;
  }

  regions.push_back(RegionSuccessor(&getBodyRegion()));
}

ValueRange AllocaScopeOp::getSuccessorInputs(RegionSuccessor successor) {
  return successor.isParent() ? ValueRange(getResults()) : ValueRange();
}

/// Given an operation, return whether this op is guaranteed to
/// allocate an AutomaticAllocationScopeResource.
static bool isGuaranteedAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return false;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Given an operation, return whether this op itself could
/// allocate an AutomaticAllocationScopeResource. Note that
/// this will not check whether an operation contained within
/// the op can allocate.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
  // This op itself doesn't create a stack allocation;
  // the inner allocation should be handled separately.
  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
    return false;
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return true;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Return whether this op is the last non-terminating op
/// in a region. That is to say, it is in a one-block region
/// and is only followed by a terminator. This prevents
/// extending the lifetime of allocations.
static bool lastNonTerminatorInRegion(Operation *op) {
  return op->getBlock()->mightHaveTerminator() &&
         op->getNextNode() == op->getBlock()->getTerminator() &&
         op->getParentRegion()->hasOneBlock();
}

/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
/// or it contains no allocation.
struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {
    bool hasPotentialAlloca =
        op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
            if (alloc == op)
              return WalkResult::advance();
            if (isOpItselfPotentialAutomaticAllocation(alloc))
              return WalkResult::interrupt();
            if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
              return WalkResult::skip();
            return WalkResult::advance();
          }).wasInterrupted();

    // If this contains no potential allocation, it is always legal to
    // inline. Otherwise, consider two conditions:
    if (hasPotentialAlloca) {
      // If the parent isn't an allocation scope, or we are not the last
      // non-terminator op in the parent, we will extend the lifetime.
      if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
        return failure();
      if (!lastNonTerminatorInRegion(op))
        return failure();
    }

    Block *block = &op.getRegion().front();
    Operation *terminator = block->getTerminator();
    ValueRange results = terminator->getOperands();
    rewriter.inlineBlockBefore(block, op);
    rewriter.replaceOp(op, results);
    rewriter.eraseOp(terminator);
    return success();
  }
};

/// Move allocations into an allocation scope, if it is legal to
/// move them (e.g. their operands are available at the location
/// the op would be moved to).
struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {

    if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    Operation *lastParentWithoutScope = op->getParentOp();

    if (!lastParentWithoutScope ||
        lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    // Only apply if this is the last non-terminator op in the block
    // (lest lifetime be extended) of a one-block region.
    if (!lastNonTerminatorInRegion(op) ||
        !lastNonTerminatorInRegion(lastParentWithoutScope))
      return failure();

    while (!lastParentWithoutScope->getParentOp()
                ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||
          !lastNonTerminatorInRegion(lastParentWithoutScope))
        return failure();
    }
    assert(lastParentWithoutScope->getParentOp()
               ->hasTrait<OpTrait::AutomaticAllocationScope>());

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

    SmallVector<Operation *> toHoist;
    op->walk([&](Operation *alloc) {
      if (!isGuaranteedAutomaticAllocation(alloc))
        return WalkResult::skip();

      // If any operand is not defined before the location of
      // lastParentWithoutScope (i.e. where we would hoist to), skip.
      if (llvm::any_of(alloc->getOperands(), [&](Value v) {
            return containingRegion->isAncestor(v.getParentRegion());
          }))
        return WalkResult::skip();
      toHoist.push_back(alloc);
      return WalkResult::advance();
    });

    if (toHoist.empty())
      return failure();
    rewriter.setInsertionPoint(lastParentWithoutScope);
    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
    return success();
  }
};

void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

void AssumeAlignmentOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "assume_align");
}

OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
  if (!source)
    return {};
  if (source.getAlignment() != getAlignment())
    return {};
  return getMemref();
}
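
// For example (illustrative): a chain of identical alignment assumptions
// folds to the earlier assumption:
//
//   %0 = memref.assume_alignment %arg, 16 : memref<?xf32>
//   %1 = memref.assume_alignment %0, 16 : memref<?xf32>
//
// Here %1 folds to %0.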

FailureOr<std::optional<SmallVector<Value>>>
AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
}

FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
                                                            int resultIndex,
                                                            int dim) {
  assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
  return getMixedSize(builder, getLoc(), getMemref(), dim);
}

//===----------------------------------------------------------------------===//
// DistinctObjectsOp
//===----------------------------------------------------------------------===//

LogicalResult DistinctObjectsOp::verify() {
  if (getOperandTypes() != getResultTypes())
    return emitOpError("operand types and result types must match");

  if (getOperandTypes().empty())
    return emitOpError("expected at least one operand");

  return success();
}

LogicalResult DistinctObjectsOp::inferReturnTypes(
    MLIRContext * /*context*/, std::optional<Location> /*location*/,
    ValueRange operands, DictionaryAttr /*attributes*/,
    PropertyRef /*properties*/, RegionRange /*regions*/,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
  return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable
/// memref.cast operations are typically inserted as `view` and `subview` ops
/// are canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
///    element type and rank.
/// 2. each of the source's size, offset or stride has more static information
///    than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
      failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        ShapedType::isStatic(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  if (inputs == outputs)
    return true;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      // Note that for dimensions of size 1, the stride can differ.
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) ||
                a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &[index, aStride] : enumerate(aStrides)) {
        if (aT.getDimSize(index) == 1 || bT.getDimSize(index) == 1)
          continue;
        if (!checkCompatible(aStride, bStrides[index]))
          return false;
      }
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
          aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
  }

  return false;
}
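
// For example (illustrative): the following casts are compatible because the
// element type and memory space match,
//
//   memref.cast %0 : memref<8xf32> to memref<?xf32>   // ranked -> ranked
//   memref.cast %1 : memref<?xf32> to memref<*xf32>   // ranked -> unranked
//
// while a cast between two distinct unranked types, or between different
// element types or memory spaces, is not.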

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

FailureOr<std::optional<SmallVector<Value>>>
CastOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold memref.copy(%x, %x).
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())
      return failure();

    rewriter.eraseOp(copyOp);
    return success();
  }
};

/// Fold copies involving a zero-element memref: such copies are no-ops.
struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  static bool isEmptyMemRef(BaseMemRefType type) {
    return type.hasRank() && llvm::is_contained(type.getShape(), 0);
  }

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {
      rewriter.eraseOp(copyOp);
      return success();
    }

    return failure();
  }
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
}
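
// For example (illustrative): both of the following copies are erased by the
// patterns above, the first as a self copy and the second as a copy between
// zero-element memrefs:
//
//   memref.copy %a, %a : memref<4xf32> to memref<4xf32>
//   memref.copy %e0, %e1 : memref<0xf32> to memref<0xf32>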

/// If the source/target of a CopyOp is a CastOp that does not modify the shape
/// and element type, the cast can be skipped. Such CastOps only cast the
/// layout of the type.
static LogicalResult foldCopyOfCast(CopyOp op) {
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      return success();
    }
  }
  return failure();
}
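
// For example (illustrative):
//
//   %0 = memref.cast %arg : memref<8x16xf32> to memref<?x?xf32>
//   memref.copy %0, %dst : memref<?x?xf32> to memref<?x?xf32>
//
// folds to a copy directly from %arg.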

LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {

  /// copy(memrefcast) -> copy
  return foldCopyOfCast(*this);
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}

/// Return a map with key being elements in `vals` and data being number of
/// occurrences of it. Use std::map, since the `vals` here are strides and the
/// dynamic stride value is the same as the tombstone value for
/// `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Returns the set of source dimensions that are dropped in a rank reduction.
/// For each result dimension in order, matches the leftmost unmatched source
/// dimension with the same size. Source dimensions not matched are dropped.
///
/// Example: memref<1x8x1x3> to memref<1x8x3>. Source sizes [1, 8, 1, 3],
/// result [1, 8, 3]. Match result[0]=1 -> source dim 0, result[1]=8 -> source
/// dim 1, result[2]=3 -> source dim 3. Source dim 2 is unmatched and dropped.
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMaskByPosition(MemRefType originalType,
                                         MemRefType reducedType,
                                         ArrayRef<OpFoldResult> sizes) {
  int64_t rankReduction = originalType.getRank() - reducedType.getRank();
  if (rankReduction <= 0)
    return llvm::SmallBitVector(originalType.getRank());

  // Build source sizes from subview sizes (one per source dim).
  SmallVector<int64_t> sourceSizes(originalType.getRank());
  for (const auto &it : llvm::enumerate(sizes)) {
    if (std::optional<int64_t> cst = getConstantIntValue(it.value()))
      sourceSizes[it.index()] = *cst;
    else
      sourceSizes[it.index()] = ShapedType::kDynamic;
  }

  ArrayRef<int64_t> resultSizes = reducedType.getShape();
  llvm::SmallBitVector usedSourceDims(originalType.getRank());
  int64_t startJ = 0;
  for (int64_t resultSize : resultSizes) {
    bool matched = false;
    for (int64_t j = startJ; j < originalType.getRank(); ++j) {
      if (sourceSizes[j] == resultSize) {
        usedSourceDims.set(j);
        matched = true;
        startJ = j + 1;
        break;
      }
    }
    if (!matched)
      return failure();
  }

  llvm::SmallBitVector unusedDims(originalType.getRank());
  for (int64_t i = 0; i < originalType.getRank(); ++i)
    if (!usedSourceDims.test(i))
      unusedDims.set(i);
  return unusedDims;
}

/// Returns the set of source dimensions that are dropped in a rank reduction.
/// A dimension is dropped if its stride is dropped; uses stride occurrence
/// counting to disambiguate when multiple unit dims exist.
///
/// Example: memref<1x1x?xf32, strided<[?, 4, 1]>> to memref<1x4xf32,
/// strided<[4, 1]>>. Source strides [?, 4, 1], candidate [4, 1]. Dim 0 (stride
/// ?) can be dropped; dim 1 (stride 4) must be kept. Source dim 0 is dropped.
static FailureOr<llvm::SmallBitVector> computeMemRefRankReductionMaskByStrides(
    MemRefType originalType, MemRefType reducedType,
    ArrayRef<int64_t> originalStrides, ArrayRef<int64_t> candidateStrides,
    llvm::SmallBitVector unusedDims) {
  // Track the number of occurrences of the strides in the original type
  // and the candidate type. For each unused dim, that stride should not be
  // present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to figure out exactly
  // which dim corresponds to which stride; we just need to verify that the
  // number of repetitions of a stride in the original + the number of unused
  // dims with that stride == the number of repetitions of that stride in the
  // candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this dim is not dropped. Keep it as is.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. We can't have a stride in the reduced rank
      // type that wasn't in the original one.
      return failure();
    }
  }
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() !=
      originalType.getRank())
    return failure();
  return unusedDims;
}

/// Given the `originalType` and a `reducedType` whose shape is assumed to be
/// a subset of `originalType` with some `1` entries erased, return the set of
/// indices that specifies which of the entries of the original shape are
/// dropped to obtain the reduced shape.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped, its stride must be dropped
/// too.
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the number
  // of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
      failed(
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
    return failure();

  // Try stride-based matching first when we have meaningful static stride
  // info (this preserves static strides). Fall back to position-based
  // matching otherwise.
  auto hasNonTrivialStaticStride = [](ArrayRef<int64_t> strides) {
    // The innermost stride 1 is trivial for row-major and does not help
    // disambiguate.
    if (strides.size() <= 1)
      return false;
    return llvm::any_of(strides.drop_back(),
                        [](int64_t s) { return !ShapedType::isDynamic(s); });
  };
  if (hasNonTrivialStaticStride(originalStrides) ||
      hasNonTrivialStaticStride(candidateStrides)) {
    FailureOr<llvm::SmallBitVector> strideBased =
        computeMemRefRankReductionMaskByStrides(originalType, reducedType,
                                                originalStrides,
                                                candidateStrides, unusedDims);
    if (succeeded(strideBased))
      return *strideBased;
  }
  return computeMemRefRankReductionMaskByPosition(originalType, reducedType,
                                                  sizes);
}

llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  return *unusedDims;
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    // The result dim is dynamic (the static case was handled above). Dropped
    // dims always have static size 1, so dynamic source sizes are never
    // dropped and map in order to the dynamic result dims. Find the k-th
    // dynamic source size, where k is the dynamic dim index of the result dim.
    unsigned dynamicResultDimIdx = memrefType.getDynamicDimIndex(unsignedIndex);
    unsigned dynamicIdx = 0;
    for (OpFoldResult size : subview.getMixedSizes()) {
      if (llvm::isa<Attribute>(size))
        continue;
      if (dynamicIdx == dynamicResultDimIdx)
        return size;
      dynamicIdx++;
    }
    return {};
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
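
// For example (illustrative):
//
//   %0 = memref.alloc(%n) : memref<4x?xf32>
//   %d1 = memref.dim %0, %c1 : memref<4x?xf32>
//
// %d1 folds to %n, and `memref.dim %0, %c0` folds to the constant 4.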

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    // dim of a memref reshape can be folded if dim.getIndex() dominates the
    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
    // cheaply check that either of the following conditions hold:
    //   1. dim.getIndex() is defined in the same block as reshape but before
    //      reshape.
    //   2. dim.getIndex() is defined in a parent block of reshape.

    // Check condition 1.
    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
        }
      } // else dim.getIndex is a block argument to reshape->getBlock and
        // dominates reshape.
    } // Check condition 2.
    else if (dim->getBlock() != reshape->getBlock() &&
             !dim.getIndex().getParentRegion()->isProperAncestor(
                 reshape->getParentRegion())) {
      // If dim and reshape are in the same block but dim.getIndex() isn't, we
      // already know dim.getIndex() dominates reshape without calling
      // `isProperAncestor`.
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");
    }

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};
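
// A minimal sketch of what DimOfMemRefReshape does (illustrative):
//
//   %r = memref.reshape %src(%shape)
//       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %i : memref<?x?xf32>
//
// rewrites %d to a load from the shape operand:
//
//   %d = memref.load %shape[%i] : memref<2xindex>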

} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride
//               : memref<3076 x f32, 0>,
//                 memref<1024 x f32, 2>,
//                 memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  OpAsmParser::UnresolvedOperand tagMemrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

void DmaStartOp::setMemrefsAndIndices(RewriterBase &rewriter, Value newSrc,
                                      ValueRange newSrcIndices, Value newDst,
                                      ValueRange newDstIndices) {
  /// dma_start has special handling for variadic rank.
  SmallVector<Value> newOperands;
  newOperands.push_back(newSrc);
  llvm::append_range(newOperands, newSrcIndices);
  newOperands.push_back(newDst);
  llvm::append_range(newOperands, newDstIndices);
  newOperands.push_back(getNumElements());
  newOperands.push_back(getTagMemRef());
  llvm::append_range(newOperands, getTagIndices());
  if (isStrided()) {
    newOperands.push_back(getStride());
    newOperands.push_back(getNumElementsPerStride());
  }

  rewriter.modifyOpInPlace(*this, [&]() { (*this)->setOperands(newOperands); });
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

//===----------------------------------------------------------------------===//
// ExtractStridedMetadataOp
//===----------------------------------------------------------------------===//

/// The number and type of the results are inferred from the
/// shape of the source.
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType =
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  // Base.
  inferredReturnTypes.push_back(memrefType);
  // Offset.
  inferredReturnTypes.push_back(indexType);
  // Sizes and strides.
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}
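
// For example (illustrative): for a source of type
// memref<4x?xf32, strided<[?, 1], offset: ?>>, the inferred result types are
// (memref<f32>, index, index, index, index, index): the base buffer, the
// offset, two sizes, and two strides.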
1490
1491void ExtractStridedMetadataOp::getAsmResultNames(
1492 function_ref<void(Value, StringRef)> setNameFn) {
1493 setNameFn(getBaseBuffer(), "base_buffer");
1494 setNameFn(getOffset(), "offset");
1495 // For multi-result to work properly with pretty names and packed syntax `x:3`
1496 // we can only give a pretty name to the first value in the pack.
1497 if (!getSizes().empty()) {
1498 setNameFn(getSizes().front(), "sizes");
1499 setNameFn(getStrides().front(), "strides");
1500 }
1501}
1502
1503/// Helper function to perform the replacement of all constant uses of `values`
1504/// by a materialized constant extracted from `maybeConstants`.
1505/// `values` and `maybeConstants` are expected to have the same size.
1506template <typename Container>
1507static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1508 Container values,
1509 ArrayRef<OpFoldResult> maybeConstants) {
1510 assert(values.size() == maybeConstants.size() &&
1511 " expected values and maybeConstants of the same size");
1512 bool atLeastOneReplacement = false;
1513 for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1514 // Don't materialize a constant if there are no uses: this would indice
1515 // infinite loops in the driver.
1516 if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1517 continue;
1518 assert(isa<Attribute>(maybeConstant) &&
1519 "The constified value should be either unchanged (i.e., == result) "
1520 "or a constant");
1521 Value constantVal = arith::ConstantIndexOp::create(
1522 rewriter, loc,
1523 llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1524 for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1525 // Note: we update the uses directly rather than via modifyOpInPlace
1526 // because C++17 lambdas cannot capture structured bindings.
1527 op->replaceUsesOfWith(result, constantVal);
1528 atLeastOneReplacement = true;
1529 }
1530 }
1531 return atLeastOneReplacement;
1532}
1533
1534LogicalResult
1535ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1536 SmallVectorImpl<OpFoldResult> &results) {
1537 OpBuilder builder(*this);
1538
1539 bool atLeastOneReplacement = replaceConstantUsesOf(
1540 builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1541 getConstifiedMixedOffset());
1542 atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1543 getConstifiedMixedSizes());
1544 atLeastOneReplacement |= replaceConstantUsesOf(
1545 builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1546
1547 // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
1548 if (auto prev = getSource().getDefiningOp<CastOp>())
1549 if (isa<MemRefType>(prev.getSource().getType())) {
1550 getSourceMutable().assign(prev.getSource());
1551 atLeastOneReplacement = true;
1552 }
1553
1554 return success(atLeastOneReplacement);
1555}
1556
1557SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1558 SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1559 constifyIndexValues(values, getSource().getType().getShape());
1560 return values;
1561}
1562
1563SmallVector<OpFoldResult>
1564ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1565 SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1566 SmallVector<int64_t> staticValues;
1567 int64_t unused;
1568 LogicalResult status =
1569 getSource().getType().getStridesAndOffset(staticValues, unused);
1570 (void)status;
1571 assert(succeeded(status) && "could not get strides from type");
1572 constifyIndexValues(values, staticValues);
1573 return values;
1574}
1575
1576OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1577 OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1578 SmallVector<OpFoldResult> values(1, offsetOfr);
1579 SmallVector<int64_t> staticValues, unused;
1580 int64_t offset;
1581 LogicalResult status =
1582 getSource().getType().getStridesAndOffset(unused, offset);
1583 (void)status;
1584 assert(succeeded(status) && "could not get offset from type");
1585 staticValues.push_back(offset);
1586 constifyIndexValues(values, staticValues);
1587 return values[0];
1588}
1589
1590//===----------------------------------------------------------------------===//
1591// GenericAtomicRMWOp
1592//===----------------------------------------------------------------------===//
1593
1594void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1595 Value memref, ValueRange ivs) {
1596 OpBuilder::InsertionGuard g(builder);
1597 result.addOperands(memref);
1598 result.addOperands(ivs);
1599
1600 if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1601 Type elementType = memrefType.getElementType();
1602 result.addTypes(elementType);
1603
1604 Region *bodyRegion = result.addRegion();
1605 builder.createBlock(bodyRegion);
1606 bodyRegion->addArgument(elementType, memref.getLoc());
1607 }
1608}
1609
1610LogicalResult GenericAtomicRMWOp::verify() {
1611 auto &body = getRegion();
1612 if (body.getNumArguments() != 1)
1613 return emitOpError("expected a single entry block argument");
1614
1615 if (getResult().getType() != body.getArgument(0).getType())
1616 return emitOpError("expected block argument type to match result type");
1617
1618 bool hasSideEffects =
1619 body.walk([&](Operation *nestedOp) {
1620 if (isMemoryEffectFree(nestedOp))
1621 return WalkResult::advance();
1622 nestedOp->emitError(
1623 "body of 'memref.generic_atomic_rmw' should contain "
1624 "only operations with no side effects");
1625 return WalkResult::interrupt();
1626 })
1627 .wasInterrupted();
1628 return hasSideEffects ? failure() : success();
1629}
1630
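/// The custom assembly form handled by the parser and printer below mirrors
/// the op documentation (SSA names are illustrative):
///
/// ```
/// %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
///   ^bb0(%current_value : f32):
///     %c1 = arith.constant 1.0 : f32
///     %inc = arith.addf %c1, %current_value : f32
///     memref.atomic_yield %inc : f32
/// }
/// ```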
1631ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1632 OperationState &result) {
1633 OpAsmParser::UnresolvedOperand memref;
1634 Type memrefType;
1635 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1636
1637 Type indexType = parser.getBuilder().getIndexType();
1638 if (parser.parseOperand(memref) ||
1639 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1640 parser.parseColonType(memrefType) ||
1641 parser.resolveOperand(memref, memrefType, result.operands) ||
1642 parser.resolveOperands(ivs, indexType, result.operands))
1643 return failure();
1644
1645 Region *body = result.addRegion();
1646 if (parser.parseRegion(*body, {}) ||
1647 parser.parseOptionalAttrDict(result.attributes))
1648 return failure();
1649 result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1650 return success();
1651}
1652
1653void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1654 p << ' ' << getMemref() << "[" << getIndices()
1655 << "] : " << getMemref().getType() << ' ';
1656 p.printRegion(getRegion());
1657 p.printOptionalAttrDict((*this)->getAttrs());
1658}
1659
1660TypedValue<MemRefType> GenericAtomicRMWOp::getAccessedMemref() {
1661 return getMemref();
1662}
1663
1664std::optional<SmallVector<Value>> GenericAtomicRMWOp::updateMemrefAndIndices(
1665 RewriterBase &rewriter, Value newMemref, ValueRange newIndices) {
1666 rewriter.modifyOpInPlace(*this, [&]() {
1667 getMemrefMutable().assign(newMemref);
1668 getIndicesMutable().assign(newIndices);
1669 });
1670 return std::nullopt;
1671}
1672
1673//===----------------------------------------------------------------------===//
1674// AtomicYieldOp
1675//===----------------------------------------------------------------------===//
1676
1677LogicalResult AtomicYieldOp::verify() {
1678 Type parentType = (*this)->getParentOp()->getResultTypes().front();
1679 Type resultType = getResult().getType();
1680 if (parentType != resultType)
1681 return emitOpError() << "types mismatch between yield op: " << resultType
1682 << " and its parent: " << parentType;
1683 return success();
1684}
1685
1686//===----------------------------------------------------------------------===//
1687// GlobalOp
1688//===----------------------------------------------------------------------===//
1689
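/// The type and initial value printed/parsed below appear in forms such as
/// the following sketch (adapted from the op documentation):
///
/// ```
/// memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>
/// memref.global @z : memref<3xf16> = uninitialized
/// memref.global "private" @external : memref<4xi32>
/// ```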
1690 static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1691 TypeAttr type,
1692 Attribute initialValue) {
1693 p << type;
1694 if (!op.isExternal()) {
1695 p << " = ";
1696 if (op.isUninitialized())
1697 p << "uninitialized";
1698 else
1699 p.printAttributeWithoutType(initialValue);
1700 }
1701}
1702
1703static ParseResult
1704 parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1705 Attribute &initialValue) {
1706 Type type;
1707 if (parser.parseType(type))
1708 return failure();
1709
1710 auto memrefType = llvm::dyn_cast<MemRefType>(type);
1711 if (!memrefType || !memrefType.hasStaticShape())
1712 return parser.emitError(parser.getNameLoc())
1713 << "type should be static shaped memref, but got " << type;
1714 typeAttr = TypeAttr::get(type);
1715
1716 if (parser.parseOptionalEqual())
1717 return success();
1718
1719 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1720 initialValue = UnitAttr::get(parser.getContext());
1721 return success();
1722 }
1723
1724 Type tensorType = getTensorTypeFromMemRefType(memrefType);
1725 if (parser.parseAttribute(initialValue, tensorType))
1726 return failure();
1727 if (!llvm::isa<ElementsAttr>(initialValue))
1728 return parser.emitError(parser.getNameLoc())
1729 << "initial value should be a unit or elements attribute";
1730 return success();
1731}
1732
1733LogicalResult GlobalOp::verify() {
1734 auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1735 if (!memrefType || !memrefType.hasStaticShape())
1736 return emitOpError("type should be static shaped memref, but got ")
1737 << getType();
1738
1739 // Verify that the initial value, if present, is either a unit attribute or
1740 // an elements attribute.
1741 if (getInitialValue().has_value()) {
1742 Attribute initValue = getInitialValue().value();
1743 if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1744 return emitOpError("initial value should be a unit or elements "
1745 "attribute, but got ")
1746 << initValue;
1747
1748 // Check that the type of the initial value is compatible with the type of
1749 // the global variable.
1750 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1751 // Check the element types match.
1752 auto initElementType =
1753 cast<TensorType>(elementsAttr.getType()).getElementType();
1754 auto memrefElementType = memrefType.getElementType();
1755
1756 if (initElementType != memrefElementType)
1757 return emitOpError("initial value element expected to be of type ")
1758 << memrefElementType << ", but was of type " << initElementType;
1759
1760 // Check the shapes match, given that memref globals can only produce
1761 // statically shaped memrefs and elements literal type must have a static
1762 // shape we can assume both types are shaped.
1763 auto initShape = elementsAttr.getShapedType().getShape();
1764 auto memrefShape = memrefType.getShape();
1765 if (initShape != memrefShape)
1766 return emitOpError("initial value shape expected to be ")
1767 << memrefShape << " but was " << initShape;
1768 }
1769 }
1770
1771 // TODO: verify visibility for declarations.
1772 return success();
1773}
1774
1775ElementsAttr GlobalOp::getConstantInitValue() {
1776 auto initVal = getInitialValue();
1777 if (getConstant() && initVal.has_value())
1778 return llvm::cast<ElementsAttr>(initVal.value());
1779 return {};
1780}
1781
1782//===----------------------------------------------------------------------===//
1783// GetGlobalOp
1784//===----------------------------------------------------------------------===//
1785
1786LogicalResult
1787GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1788 // Verify that the result type is same as the type of the referenced
1789 // memref.global op.
1790 auto global =
1791 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1792 if (!global)
1793 return emitOpError("'")
1794 << getName() << "' does not reference a valid global memref";
1795
1796 Type resultType = getResult().getType();
1797 if (global.getType() != resultType)
1798 return emitOpError("result type ")
1799 << resultType << " does not match type " << global.getType()
1800 << " of the global memref @" << getName();
1801 return success();
1802}
1803
1804//===----------------------------------------------------------------------===//
1805// LoadOp
1806//===----------------------------------------------------------------------===//
1807
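/// Loads from a constant splat global fold to the splat value irrespective
/// of the indices. A sketch (names are illustrative):
///
/// ```
/// memref.global "private" constant @g : memref<2xf32> = dense<4.2>
///
/// %m = memref.get_global @g : memref<2xf32>
/// %v = memref.load %m[%i] : memref<2xf32>  // Folds to the constant 4.2.
/// ```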
1808OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1809 /// load(memrefcast) -> load
1810 if (succeeded(foldMemRefCast(*this)))
1811 return getResult();
1812
1813 // Fold load from a global constant memref.
1814 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
1815 if (!getGlobalOp)
1816 return {};
1817
1818 // Get to the memref.global defining the symbol.
1819 auto global = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
1820 getGlobalOp, getGlobalOp.getNameAttr());
1821 if (!global)
1822 return {};
1823 // If it's a splat constant, we can fold irrespective of indices.
1824 auto splatAttr =
1825 dyn_cast_or_null<SplatElementsAttr>(global.getConstantInitValue());
1826 if (!splatAttr)
1827 return {};
1828
1829 return splatAttr.getSplatValue<Attribute>();
1830}
1831
1832TypedValue<MemRefType> LoadOp::getAccessedMemref() { return getMemref(); }
1833
1834std::optional<SmallVector<Value>>
1835LoadOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
1836 ValueRange newIndices) {
1837 rewriter.modifyOpInPlace(*this, [&]() {
1838 getMemrefMutable().assign(newMemref);
1839 getIndicesMutable().assign(newIndices);
1840 });
1841 return std::nullopt;
1842}
1843
1844FailureOr<std::optional<SmallVector<Value>>>
1845LoadOp::bubbleDownCasts(OpBuilder &builder) {
1847 getResult());
1848}
1849
1850//===----------------------------------------------------------------------===//
1851// MemorySpaceCastOp
1852//===----------------------------------------------------------------------===//
1853
1854void MemorySpaceCastOp::getAsmResultNames(
1855 function_ref<void(Value, StringRef)> setNameFn) {
1856 setNameFn(getResult(), "memspacecast");
1857}
1858
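/// Casts are compatible exactly when the source and destination types differ
/// only in their memory space, e.g. (illustrative):
///
/// ```
/// %1 = memref.memory_space_cast %0 : memref<4xf32, 1> to memref<4xf32>
/// ```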
1859bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1860 if (inputs.size() != 1 || outputs.size() != 1)
1861 return false;
1862 Type a = inputs.front(), b = outputs.front();
1863 auto aT = llvm::dyn_cast<MemRefType>(a);
1864 auto bT = llvm::dyn_cast<MemRefType>(b);
1865
1866 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1867 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1868
1869 if (aT && bT) {
1870 if (aT.getElementType() != bT.getElementType())
1871 return false;
1872 if (aT.getLayout() != bT.getLayout())
1873 return false;
1874 if (aT.getShape() != bT.getShape())
1875 return false;
1876 return true;
1877 }
1878 if (uaT && ubT) {
1879 return uaT.getElementType() == ubT.getElementType();
1880 }
1881 return false;
1882}
1883
1884OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1885 // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1886 // t2)
1887 if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1888 getSourceMutable().assign(parentCast.getSource());
1889 return getResult();
1890 }
1891 return Value{};
1892}
1893
1894TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
1895 return getSource();
1896}
1897
1898TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
1899 return getDest();
1900}
1901
1902bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1903 PtrLikeTypeInterface src) {
1904 return isa<BaseMemRefType>(tgt) &&
1905 tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
1906}
1907
1908MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
1909 OpBuilder &b, PtrLikeTypeInterface tgt,
1910 TypedValue<PtrLikeTypeInterface> src) {
1911 assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
1912 return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
1913}
1914
1915/// The only cast we recognize as promotable is to the generic space.
1916bool MemorySpaceCastOp::isSourcePromotable() {
1917 return getDest().getType().getMemorySpace() == nullptr;
1918}
1919
1920//===----------------------------------------------------------------------===//
1921// PrefetchOp
1922//===----------------------------------------------------------------------===//
1923
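/// The custom assembly form handled below is, as a sketch (names are
/// illustrative):
///
/// ```
/// memref.prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xi32>
/// ```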
1924void PrefetchOp::print(OpAsmPrinter &p) {
1925 p << " " << getMemref() << '[';
1926 p.printOperands(getIndices());
1927 p << ']' << ", " << (getIsWrite() ? "write" : "read");
1928 p << ", locality<" << getLocalityHint();
1929 p << ">, " << (getIsDataCache() ? "data" : "instr");
1930 p.printOptionalAttrDict(
1931 (*this)->getAttrs(),
1932 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1933 p << " : " << getMemRefType();
1934}
1935
1936ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1937 OpAsmParser::UnresolvedOperand memrefInfo;
1938 SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1939 IntegerAttr localityHint;
1940 MemRefType type;
1941 StringRef readOrWrite, cacheType;
1942
1943 auto indexTy = parser.getBuilder().getIndexType();
1944 auto i32Type = parser.getBuilder().getIntegerType(32);
1945 if (parser.parseOperand(memrefInfo) ||
1946 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1947 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1948 parser.parseComma() || parser.parseKeyword("locality") ||
1949 parser.parseLess() ||
1950 parser.parseAttribute(localityHint, i32Type, "localityHint",
1951 result.attributes) ||
1952 parser.parseGreater() || parser.parseComma() ||
1953 parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1954 parser.resolveOperand(memrefInfo, type, result.operands) ||
1955 parser.resolveOperands(indexInfo, indexTy, result.operands))
1956 return failure();
1957
1958 if (readOrWrite != "read" && readOrWrite != "write")
1959 return parser.emitError(parser.getNameLoc(),
1960 "rw specifier has to be 'read' or 'write'");
1961 result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1962 parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1963
1964 if (cacheType != "data" && cacheType != "instr")
1965 return parser.emitError(parser.getNameLoc(),
1966 "cache type has to be 'data' or 'instr'");
1967
1968 result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1969 parser.getBuilder().getBoolAttr(cacheType == "data"));
1970
1971 return success();
1972}
1973
1974LogicalResult PrefetchOp::verify() {
1975 if (getNumOperands() != 1 + getMemRefType().getRank())
1976 return emitOpError("too few indices");
1977
1978 return success();
1979}
1980
1981LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1982 SmallVectorImpl<OpFoldResult> &results) {
1983 // prefetch(memrefcast) -> prefetch
1984 return foldMemRefCast(*this);
1985}
1986
1987TypedValue<MemRefType> PrefetchOp::getAccessedMemref() { return getMemref(); }
1988
1989std::optional<SmallVector<Value>>
1990PrefetchOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
1991 ValueRange newIndices) {
1992 rewriter.modifyOpInPlace(*this, [&]() {
1993 getMemrefMutable().assign(newMemref);
1994 getIndicesMutable().assign(newIndices);
1995 });
1996 return std::nullopt;
1997}
1998
1999//===----------------------------------------------------------------------===//
2000// RankOp
2001//===----------------------------------------------------------------------===//
2002
2003OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
2004 // Constant fold rank when the rank of the operand is known.
2005 auto type = getOperand().getType();
2006 auto shapedType = llvm::dyn_cast<ShapedType>(type);
2007 if (shapedType && shapedType.hasRank())
2008 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
2009 return IntegerAttr();
2010}
2011
2012//===----------------------------------------------------------------------===//
2013// ReinterpretCastOp
2014//===----------------------------------------------------------------------===//
2015
2016void ReinterpretCastOp::getAsmResultNames(
2017 function_ref<void(Value, StringRef)> setNameFn) {
2018 setNameFn(getResult(), "reinterpret_cast");
2019}
2020
2021/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
2022/// `staticSizes` and `staticStrides` are automatically filled with
2023/// source-memref-rank sentinel values that encode dynamic entries.
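/// The built op takes the following form, as a sketch (SSA names are
/// illustrative):
///
/// ```
/// %dst = memref.reinterpret_cast %src to
///     offset: [%o], sizes: [%s0, 10], strides: [%st0, 1]
///     : memref<?x?xf32> to memref<?x10xf32, strided<[?, 1], offset: ?>>
/// ```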
2024void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
2025 MemRefType resultType, Value source,
2026 OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
2027 ArrayRef<OpFoldResult> strides,
2028 ArrayRef<NamedAttribute> attrs) {
2029 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2030 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2031 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
2032 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2033 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2034 result.addAttributes(attrs);
2035 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2036 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2037 b.getDenseI64ArrayAttr(staticSizes),
2038 b.getDenseI64ArrayAttr(staticStrides));
2039}
2040
2041void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
2042 Value source, OpFoldResult offset,
2043 ArrayRef<OpFoldResult> sizes,
2044 ArrayRef<OpFoldResult> strides,
2045 ArrayRef<NamedAttribute> attrs) {
2046 auto sourceType = cast<BaseMemRefType>(source.getType());
2047 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2048 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2049 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
2050 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2051 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2052 auto stridedLayout = StridedLayoutAttr::get(
2053 b.getContext(), staticOffsets.front(), staticStrides);
2054 auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
2055 stridedLayout, sourceType.getMemorySpace());
2056 build(b, result, resultType, source, offset, sizes, strides, attrs);
2057}
2058
2059void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
2060 MemRefType resultType, Value source,
2061 int64_t offset, ArrayRef<int64_t> sizes,
2062 ArrayRef<int64_t> strides,
2063 ArrayRef<NamedAttribute> attrs) {
2064 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
2065 sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
2066 SmallVector<OpFoldResult> strideValues =
2067 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
2068 return b.getI64IntegerAttr(v);
2069 });
2070 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
2071 strideValues, attrs);
2072}
2073
2074void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
2075 MemRefType resultType, Value source, Value offset,
2076 ValueRange sizes, ValueRange strides,
2077 ArrayRef<NamedAttribute> attrs) {
2078 SmallVector<OpFoldResult> sizeValues =
2079 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2080 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2081 strides, [](Value v) -> OpFoldResult { return v; });
2082 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
2083}
2084
2085// TODO: ponder whether we want to allow missing trailing sizes/strides that are
2086// completed automatically, like we have for subview and extract_slice.
2087LogicalResult ReinterpretCastOp::verify() {
2088 // The source and result memrefs should be in the same memory space.
2089 auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
2090 auto resultType = llvm::cast<MemRefType>(getType());
2091 if (srcType.getMemorySpace() != resultType.getMemorySpace())
2092 return emitError("different memory spaces specified for source type ")
2093 << srcType << " and result memref type " << resultType;
2094 if (failed(verifyElementTypesMatch(*this, srcType, resultType, "source",
2095 "result")))
2096 return failure();
2097
2098 // Match sizes in result memref type and in static_sizes attribute.
2099 for (auto [idx, resultSize, expectedSize] :
2100 llvm::enumerate(resultType.getShape(), getStaticSizes())) {
2101 if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
2102 return emitError("expected result type with size = ")
2103 << (ShapedType::isDynamic(expectedSize)
2104 ? std::string("dynamic")
2105 : std::to_string(expectedSize))
2106 << " instead of " << resultSize << " in dim = " << idx;
2107 }
2108
2109 // Match offset and strides in the static_offsets and static_strides
2110 // attributes. If the result memref type has no layout specified, an
2111 // identity layout is assumed.
2112 int64_t resultOffset;
2113 SmallVector<int64_t, 4> resultStrides;
2114 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
2115 return emitError("expected result type to have strided layout but found ")
2116 << resultType;
2117
2118 // Match offset in result memref type and in static_offsets attribute.
2119 int64_t expectedOffset = getStaticOffsets().front();
2120 if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
2121 return emitError("expected result type with offset = ")
2122 << (ShapedType::isDynamic(expectedOffset)
2123 ? std::string("dynamic")
2124 : std::to_string(expectedOffset))
2125 << " instead of " << resultOffset;
2126
2127 // Match strides in result memref type and in static_strides attribute.
2128 for (auto [idx, resultStride, expectedStride] :
2129 llvm::enumerate(resultStrides, getStaticStrides())) {
2130 if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
2131 return emitError("expected result type with stride = ")
2132 << (ShapedType::isDynamic(expectedStride)
2133 ? std::string("dynamic")
2134 : std::to_string(expectedStride))
2135 << " instead of " << resultStride << " in dim = " << idx;
2136 }
2137
2138 return success();
2139}
2140
2141OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
2142 Value src = getSource();
2143 auto getPrevSrc = [&]() -> Value {
2144 // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
2145 if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
2146 return prev.getSource();
2147
2148 // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
2149 if (auto prev = src.getDefiningOp<CastOp>())
2150 return prev.getSource();
2151
2152 // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
2153 // are 0.
2154 if (auto prev = src.getDefiningOp<SubViewOp>())
2155 if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
2156 return prev.getSource();
2157
2158 return nullptr;
2159 };
2160
2161 if (auto prevSrc = getPrevSrc()) {
2162 getSourceMutable().assign(prevSrc);
2163 return getResult();
2164 }
2165
2166 // reinterpret_cast(x) w/o offset/shape/stride changes -> x
2167 if (ShapedType::isStaticShape(getType().getShape()) &&
2168 src.getType() == getType() && getStaticOffsets().front() == 0) {
2169 return src;
2170 }
2171
2172 return nullptr;
2173}
2174
2175SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
2176 SmallVector<OpFoldResult> values = getMixedSizes();
2178 return values;
2179}
2180
2181SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
2182 SmallVector<OpFoldResult> values = getMixedStrides();
2183 SmallVector<int64_t> staticValues;
2184 int64_t unused;
2185 LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
2186 (void)status;
2187 assert(succeeded(status) && "could not get strides from type");
2188 constifyIndexValues(values, staticValues);
2189 return values;
2190}
2191
2192OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
2193 SmallVector<OpFoldResult> values = getMixedOffsets();
2194 assert(values.size() == 1 &&
2195 "reinterpret_cast must have one and only one offset");
2196 SmallVector<int64_t> staticValues, unused;
2197 int64_t offset;
2198 LogicalResult status = getType().getStridesAndOffset(unused, offset);
2199 (void)status;
2200 assert(succeeded(status) && "could not get offset from type");
2201 staticValues.push_back(offset);
2202 constifyIndexValues(values, staticValues);
2203 return values[0];
2204}
2205
2206namespace {
2207/// Replace the sequence:
2208/// ```
2209/// base, offset, sizes, strides = extract_strided_metadata src
2210/// dst = reinterpret_cast base to offset, sizes, strides
2211/// ```
2212/// With
2213///
2214/// ```
2215/// dst = memref.cast src
2216/// ```
2217///
2218/// Note: The cast operation is only inserted when the types of dst and src
2219/// are not the same. E.g., when going from <4xf32> to <?xf32>.
2220///
2221/// This pattern also matches when the offset, sizes, and strides don't come
2222/// directly from the `extract_strided_metadata`'s results but can be
2223/// statically proven to hold the same values.
2224///
2225/// For instance, the following sequence would be replaced:
2226/// ```
2227/// base, offset, sizes, strides =
2228/// extract_strided_metadata memref : memref<3x4xty>
2229/// dst = reinterpret_cast base to 0, [3, 4], strides
2230/// ```
2231/// Because we know (thanks to the type of the input memref) that variables
2232/// `offset` and `sizes` will respectively hold 0 and [3, 4].
2233///
2234/// Similarly, the following sequence would be replaced:
2235/// ```
2236/// c0 = arith.constant 0
2237/// c4 = arith.constant 4
2238/// base, offset, sizes, strides =
2239/// extract_strided_metadata memref : memref<3x4xty>
2240/// dst = reinterpret_cast base to c0, [3, c4], strides
2241/// ```
2242/// Because we know that `offset` and `c0` will hold 0
2243/// and `c4` will hold 4.
2244///
2245/// If the pattern above does not match, the input of the
2246/// extract_strided_metadata is always folded into the input of the
2247/// reinterpret_cast operator. This allows dead code elimination to get rid
2248/// of the extract_strided_metadata in some cases.
2249struct ReinterpretCastOpExtractStridedMetadataFolder
2250 : public OpRewritePattern<ReinterpretCastOp> {
2251public:
2252 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2253
2254 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2255 PatternRewriter &rewriter) const override {
2256 auto extractStridedMetadata =
2257 op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2258 if (!extractStridedMetadata)
2259 return failure();
2260
2261 // Check if the reinterpret cast reconstructs a memref with the exact same
2262 // properties as the extract strided metadata.
2263 auto isReinterpretCastNoop = [&]() -> bool {
2264 // First, check that the strides are the same.
2265 if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
2266 op.getConstifiedMixedStrides()))
2267 return false;
2268
2269 // Second, check the sizes.
2270 if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
2271 op.getConstifiedMixedSizes()))
2272 return false;
2273
2274 // Finally, check the offset.
2275 assert(op.getMixedOffsets().size() == 1 &&
2276 "reinterpret_cast with more than one offset should have been "
2277 "rejected by the verifier");
2278 return extractStridedMetadata.getConstifiedMixedOffset() ==
2279 op.getConstifiedMixedOffset();
2280 };
2281
2282 if (!isReinterpretCastNoop()) {
2283 // If the extract_strided_metadata / reinterpret_cast pair can't be
2284 // completely folded, we can still fold the input of the
2285 // extract_strided_metadata into the input of the reinterpret_cast.
2286 // In some cases (e.g., static dimensions), the extract_strided_metadata
2287 // is then eliminated by dead code elimination.
2288 //
2289 // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2290 //
2291 // We can always fold the input of a extract_strided_metadata operator
2292 // to the input of a reinterpret_cast operator, because they point to
2293 // the same memory. Note that the reinterpret_cast does not use the
2294 // layout of its input memref, only its base memory pointer which is
2295 // the same as the base pointer returned by the extract_strided_metadata
2296 // operator and the base pointer of the extract_strided_metadata memref
2297 // input.
2298 rewriter.modifyOpInPlace(op, [&]() {
2299 op.getSourceMutable().assign(extractStridedMetadata.getSource());
2300 });
2301 return success();
2302 }
2303
2304 // At this point, we know that the back and forth between extract strided
2305 // metadata and reinterpret cast is a noop. However, the final type of the
2306 // reinterpret cast may not be exactly the same as the original memref.
2307 // E.g., it could be changing a dimension from static to dynamic. Check that
2308 // here and add a cast if necessary.
2309 Type srcTy = extractStridedMetadata.getSource().getType();
2310 if (srcTy == op.getResult().getType())
2311 rewriter.replaceOp(op, extractStridedMetadata.getSource());
2312 else
2313 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2314 extractStridedMetadata.getSource());
2315
2316 return success();
2317 }
2318};
2319
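/// Replaces the offset, sizes, and strides of a reinterpret_cast with
/// constants whenever they can be statically inferred, casting the result
/// back to the original type. A sketch (SSA names are illustrative):
///
/// ```
/// %c0 = arith.constant 0 : index
/// %dst = memref.reinterpret_cast %src to offset: [%c0], sizes: [4],
///     strides: [1] : memref<?xf32> to memref<4xf32, strided<[1], offset: ?>>
/// ```
/// becomes
/// ```
/// %tmp = memref.reinterpret_cast %src to offset: [0], sizes: [4],
///     strides: [1] : memref<?xf32> to memref<4xf32, strided<[1]>>
/// %dst = memref.cast %tmp
///     : memref<4xf32, strided<[1]>> to memref<4xf32, strided<[1], offset: ?>>
/// ```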
2320struct ReinterpretCastOpConstantFolder
2321 : public OpRewritePattern<ReinterpretCastOp> {
2322public:
2323 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2324
2325 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2326 PatternRewriter &rewriter) const override {
2327 unsigned srcStaticCount = llvm::count_if(
2328 llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
2329 op.getMixedStrides()),
2330 [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
2331
2332 SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
2333 SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
2334 SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
2335
2336 // If the offset is a negative constant, we can't fold it because the
2337 // resulting memref type would be invalid. In that case, we keep the
2338 // original offset.
2339 if (auto cst = getConstantIntValue(offsets[0]))
2340 if (*cst < 0)
2341 offsets[0] = op.getMixedOffsets()[0];
2342
2343 // If the size is a negative constant, we can't fold it because the
2344 // resulting memref type would be invalid. In that case, we keep the
2345 // original size.
2346 for (auto it : llvm::zip(op.getMixedSizes(), sizes)) {
2347 auto &srcSizeOfr = std::get<0>(it);
2348 auto &sizeOfr = std::get<1>(it);
2349 if (auto cst = getConstantIntValue(sizeOfr))
2350 if (*cst < 0)
2351 sizeOfr = srcSizeOfr;
2352 }
2353
2354 // TODO: Compare counts rather than comparing the values directly because
2355 // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
2356 // IntegerAttrs, while constifyIndexValues (and therefore
2357 // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
2358 if (srcStaticCount ==
2359 llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
2360 [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
2361 return failure();
2362
2363 auto newReinterpretCast = ReinterpretCastOp::create(
2364 rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
2365
2366 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
2367 return success();
2368 }
2369};
2370} // namespace
2371
2372void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2373 MLIRContext *context) {
2374 results.add<ReinterpretCastOpExtractStridedMetadataFolder,
2375 ReinterpretCastOpConstantFolder>(context);
2376}
2377
2378FailureOr<std::optional<SmallVector<Value>>>
2379ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
2380 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2381}
2382
2383//===----------------------------------------------------------------------===//
2384// Reassociative reshape ops
2385//===----------------------------------------------------------------------===//
2386
2387void CollapseShapeOp::getAsmResultNames(
2388 function_ref<void(Value, StringRef)> setNameFn) {
2389 setNameFn(getResult(), "collapse_shape");
2390}
2391
2392void ExpandShapeOp::getAsmResultNames(
2393 function_ref<void(Value, StringRef)> setNameFn) {
2394 setNameFn(getResult(), "expand_shape");
2395}
2396
2397LogicalResult ExpandShapeOp::reifyResultShapes(
2398 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2399 reifiedResultShapes = {
2400 getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2401 return success();
2402}
2403
2404/// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
2405/// result and operand. Layout maps are verified separately.
2406///
2407/// If `allowMultipleDynamicDimsPerGroup` is set, multiple dynamic dimensions
2408/// are allowed in a reassociation group.
2409static LogicalResult
2410 verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2411 ArrayRef<int64_t> expandedShape,
2412 ArrayRef<ReassociationIndices> reassociation,
2413 bool allowMultipleDynamicDimsPerGroup) {
2414 // There must be one reassociation group per collapsed dimension.
2415 if (collapsedShape.size() != reassociation.size())
2416 return op->emitOpError("invalid number of reassociation groups: found ")
2417 << reassociation.size() << ", expected " << collapsedShape.size();
2418
2419 // The next expected expanded dimension index (while iterating over
2420 // reassociation indices).
2421 int64_t nextDim = 0;
2422 for (const auto &it : llvm::enumerate(reassociation)) {
2423 ReassociationIndices group = it.value();
2424 int64_t collapsedDim = it.index();
2425
2426 bool foundDynamic = false;
2427 for (int64_t expandedDim : group) {
2428 if (expandedDim != nextDim++)
2429 return op->emitOpError("reassociation indices must be contiguous");
2430
2431 if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2432 return op->emitOpError("reassociation index ")
2433 << expandedDim << " is out of bounds";
2434
2435 // Check if there are multiple dynamic dims in a reassociation group.
2436 if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2437 if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2438 return op->emitOpError(
2439 "at most one dimension in a reassociation group may be dynamic");
2440 foundDynamic = true;
2441 }
2442 }
2443
2444 // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2445 if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2446 return op->emitOpError("collapsed dim (")
2447 << collapsedDim
2448 << ") must be dynamic if and only if reassociation group is "
2449 "dynamic";
2450
2451 // If all dims in the reassociation group are static, the size of the
2452 // collapsed dim can be verified.
2453 if (!foundDynamic) {
2454 int64_t groupSize = 1;
2455 for (int64_t expandedDim : group)
2456 groupSize *= expandedShape[expandedDim];
2457 if (groupSize != collapsedShape[collapsedDim])
2458 return op->emitOpError("collapsed dim size (")
2459 << collapsedShape[collapsedDim]
2460 << ") must equal reassociation group size (" << groupSize << ")";
2461 }
2462 }
2463
2464 if (collapsedShape.empty()) {
2465 // Rank 0: All expanded dimensions must be 1.
2466 for (int64_t d : expandedShape)
2467 if (d != 1)
2468 return op->emitOpError(
2469 "rank 0 memrefs can only be extended/collapsed with/from ones");
2470 } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2471 // Rank >= 1: Number of dimensions among all reassociation groups must match
2472 // the result memref rank.
2473 return op->emitOpError("expanded rank (")
2474 << expandedShape.size()
2475 << ") inconsistent with number of reassociation indices (" << nextDim
2476 << ")";
2477 }
2478
2479 return success();
2480}
2481
2482SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2483 return getSymbolLessAffineMaps(getReassociationExprs());
2484}
2485
2486SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2487 return convertReassociationIndicesToExprs(getContext(),
2488 getReassociationIndices());
2489}
2490
2491SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2492 return getSymbolLessAffineMaps(getReassociationExprs());
2493}
2494
2495SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2496 return convertReassociationIndicesToExprs(getContext(),
2497 getReassociationIndices());
2498}
2499
2500/// Compute the layout map after expanding a given source MemRef type with the
2501/// specified reassociation indices.
2502static FailureOr<StridedLayoutAttr>
2503computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2504 ArrayRef<ReassociationIndices> reassociation) {
2505 int64_t srcOffset;
2506 SmallVector<int64_t> srcStrides;
2507 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2508 return failure();
2509 assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2510
2511 // 1-1 mapping between srcStrides and reassociation packs.
2512 // Each srcStride starts with the given value and gets expanded according to
2513 // the proper entries in resultShape.
2514 // Example:
2515 // srcStrides = [10000, 1 , 100 ],
2516 // reassociations = [ [0], [1], [2, 3, 4]],
2517 // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2518 // -> For the purpose of stride calculation, the useful sizes are:
2519 // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2520 // resultStrides = [10000, 1, 600, 200, 100]
2521 // Note that a stride does not get expanded along the first entry of each
2522 // shape pack.
2523 SmallVector<int64_t> reverseResultStrides;
2524 reverseResultStrides.reserve(resultShape.size());
2525 unsigned shapeIndex = resultShape.size() - 1;
2526 for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2527 ReassociationIndices reassoc = std::get<0>(it);
2528 int64_t currentStrideToExpand = std::get<1>(it);
2529 for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2530 reverseResultStrides.push_back(currentStrideToExpand);
2531 currentStrideToExpand =
2532 (SaturatedInteger::wrap(currentStrideToExpand) *
2533 SaturatedInteger::wrap(resultShape[shapeIndex--]))
2534 .asInteger();
2535 }
2536 }
2537 auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2538 resultStrides.resize(resultShape.size(), 1);
2539 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2540}
2541
2542FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2543 MemRefType srcType, ArrayRef<int64_t> resultShape,
2544 ArrayRef<ReassociationIndices> reassociation) {
2545 if (srcType.getLayout().isIdentity()) {
2546 // If the source is contiguous (i.e., no layout map specified), so is the
2547 // result.
2548 MemRefLayoutAttrInterface layout;
2549 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2550 srcType.getMemorySpace());
2551 }
2552
2553 // Source may not be contiguous. Compute the layout map.
2554 FailureOr<StridedLayoutAttr> computedLayout =
2555 computeExpandedLayoutMap(srcType, resultShape, reassociation);
2556 if (failed(computedLayout))
2557 return failure();
2558 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2559 srcType.getMemorySpace());
2560}
2561
2562FailureOr<SmallVector<OpFoldResult>>
2563ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2564 MemRefType expandedType,
2565 ArrayRef<ReassociationIndices> reassociation,
2566 ArrayRef<OpFoldResult> inputShape) {
2567 std::optional<SmallVector<OpFoldResult>> outputShape =
2568 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2569 inputShape);
2570 if (!outputShape)
2571 return failure();
2572 return *outputShape;
2573}
2574
2575void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2576 Type resultType, Value src,
2577 ArrayRef<ReassociationIndices> reassociation,
2578 ArrayRef<OpFoldResult> outputShape) {
2579 auto [staticOutputShape, dynamicOutputShape] =
2580 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2581 build(builder, result, llvm::cast<MemRefType>(resultType), src,
2582 getReassociationIndicesAttribute(builder, reassociation),
2583 dynamicOutputShape, staticOutputShape);
2584}
2585
2586void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2587 Type resultType, Value src,
2588 ArrayRef<ReassociationIndices> reassociation) {
2589 SmallVector<OpFoldResult> inputShape =
2590 getMixedSizes(builder, result.location, src);
2591 MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2592 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2593 builder, result.location, memrefResultTy, reassociation, inputShape);
2594 // Failure of this assertion usually indicates presence of multiple
2595 // dynamic dimensions in the same reassociation group.
2596 assert(succeeded(outputShape) && "unable to infer output shape");
2597 build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2598}
2599
2600void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2601 ArrayRef<int64_t> resultShape, Value src,
2602 ArrayRef<ReassociationIndices> reassociation) {
2603 // Only ranked memref source values are supported.
2604 auto srcType = llvm::cast<MemRefType>(src.getType());
2605 FailureOr<MemRefType> resultType =
2606 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2607 // Failure of this assertion usually indicates a problem with the source
2608 // type, e.g., could not get strides/offset.
2609 assert(succeeded(resultType) && "could not compute layout");
2610 build(builder, result, *resultType, src, reassociation);
2611}
2612
2613void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2614 ArrayRef<int64_t> resultShape, Value src,
2615 ArrayRef<ReassociationIndices> reassociation,
2616 ArrayRef<OpFoldResult> outputShape) {
2617 // Only ranked memref source values are supported.
2618 auto srcType = llvm::cast<MemRefType>(src.getType());
2619 FailureOr<MemRefType> resultType =
2620 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2621 // Failure of this assertion usually indicates a problem with the source
2622 // type, e.g., could not get strides/offset.
2623 assert(succeeded(resultType) && "could not compute layout");
2624 build(builder, result, *resultType, src, reassociation, outputShape);
2625}
2626
2627LogicalResult ExpandShapeOp::verify() {
2628 MemRefType srcType = getSrcType();
2629 MemRefType resultType = getResultType();
2630
2631 if (srcType.getRank() > resultType.getRank()) {
2632 auto r0 = srcType.getRank();
2633 auto r1 = resultType.getRank();
2634 return emitOpError("has source rank ")
2635 << r0 << " and result rank " << r1 << ". This is not an expansion ("
2636 << r0 << " > " << r1 << ").";
2637 }
2638
2639 // Verify result shape.
2640 if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2641 resultType.getShape(),
2642 getReassociationIndices(),
2643 /*allowMultipleDynamicDimsPerGroup=*/true)))
2644 return failure();
2645
2646 // Compute expected result type (including layout map).
2647 FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2648 srcType, resultType.getShape(), getReassociationIndices());
2649 if (failed(expectedResultType))
2650 return emitOpError("invalid source layout map");
2651
2652 // Check actual result type.
2653 if (*expectedResultType != resultType)
2654 return emitOpError("expected expanded type to be ")
2655 << *expectedResultType << " but found " << resultType;
2656
2657 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2658 return emitOpError("expected number of static shape bounds to be equal to "
2659 "the output rank (")
2660 << resultType.getRank() << ") but found "
2661 << getStaticOutputShape().size() << " inputs instead";
2662
2663 if ((int64_t)getOutputShape().size() !=
2664 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2665 return emitOpError("mismatch in dynamic dims in output_shape and "
2666 "static_output_shape: static_output_shape has ")
2667 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2668 << " dynamic dims while output_shape has " << getOutputShape().size()
2669 << " values";
2670
2671 // Verify that the number of dynamic dims in output_shape matches the number
2672 // of dynamic dims in the result type.
2673 if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
2674 getOutputShape())))
2675 return failure();
2676
2677 // Verify if provided output shapes are in agreement with output type.
2678 DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2679 ArrayRef<int64_t> resShape = getResult().getType().getShape();
2680 for (auto [pos, shape] : llvm::enumerate(resShape)) {
2681 if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
2682 return emitOpError("invalid output shape provided at pos ") << pos;
2683 }
2684 }
2685
2686 return success();
2687}
2688
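/// Folds a memref.cast into a consuming expand_shape when the cast only
/// relaxes static information. A sketch (SSA names are illustrative):
///
/// ```
/// %c2 = arith.constant 2 : index
/// %0 = memref.cast %src : memref<2x6xf32> to memref<2x?xf32>
/// %1 = memref.expand_shape %0 [[0], [1, 2]] output_shape [2, %c2, 3]
///     : memref<2x?xf32> into memref<2x?x3xf32>
/// ```
/// becomes
/// ```
/// %e = memref.expand_shape %src [[0], [1, 2]] output_shape [2, 2, 3]
///     : memref<2x6xf32> into memref<2x2x3xf32>
/// %1 = memref.cast %e : memref<2x2x3xf32> to memref<2x?x3xf32>
/// ```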
2689struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
2690public:
2691 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2692
2693 LogicalResult matchAndRewrite(ExpandShapeOp op,
2694 PatternRewriter &rewriter) const override {
2695 auto cast = op.getSrc().getDefiningOp<CastOp>();
2696 if (!cast)
2697 return failure();
2698
2699 if (!CastOp::canFoldIntoConsumerOp(cast))
2700 return failure();
2701
2702 SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
2703 SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
2704 SmallVector<int64_t> newOutputShapeSizes;
2705
2706 // Convert output shape dims from dynamic to static where possible.
2707 for (auto [dimIdx, dimSize] : llvm::enumerate(originalOutputShape)) {
2708 std::optional<int64_t> sizeOpt = getConstantIntValue(dimSize);
2709 if (!sizeOpt.has_value()) {
2710 newOutputShapeSizes.push_back(ShapedType::kDynamic);
2711 continue;
2712 }
2713
2714 newOutputShapeSizes.push_back(sizeOpt.value());
2715 newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
2716 }
2717
2718 Value castSource = cast.getSource();
2719 auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
2720 SmallVector<ReassociationIndices> reassociationIndices =
2721 op.getReassociationIndices();
2722 for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
2723 auto newOutputShapeSizesSlice =
2724 ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
2725 bool newOutputDynamic =
2726 llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
2727 if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
2728 return rewriter.notifyMatchFailure(
2729 op, "folding cast will result in changing dynamicity in "
2730 "reassociation group");
2731 }
2732
2733 FailureOr<MemRefType> newResultTypeOrFailure =
2734 ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
2735 reassociationIndices);
2736
2737 if (failed(newResultTypeOrFailure))
2738 return rewriter.notifyMatchFailure(
2739 op, "could not compute new expanded type after folding cast");
2740
2741 if (*newResultTypeOrFailure == op.getResultType()) {
2742 rewriter.modifyOpInPlace(
2743 op, [&]() { op.getSrcMutable().assign(castSource); });
2744 } else {
2745 Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
2746 *newResultTypeOrFailure, castSource,
2747 reassociationIndices, newOutputShape);
2748 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2749 }
2750 return success();
2751 }
2752};
2753
2754void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2755 MLIRContext *context) {
2756 results.add<
2757 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2758 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2759 ExpandShapeOpMemRefCastFolder>(context);
2760}
2761
2762FailureOr<std::optional<SmallVector<Value>>>
2763ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
2764 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2765}
2766
2767/// Compute the layout map after collapsing a given source MemRef type with the
2768/// specified reassociation indices.
2769///
2770/// Note: All collapsed dims in a reassociation group must be contiguous. It is
2771/// not possible to check this by inspecting a MemRefType in the general case.
2772/// If non-contiguity cannot be checked statically, the collapse is assumed to
2773/// be valid (and thus accepted by this function) unless `strict = true`.
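/// A sketch of the stride computation (values are illustrative):
/// srcShape = [2, 3, 4], srcStrides = [12, 4, 1],
/// reassociation = [[0, 1], [2]]
/// -> resultStrides = [4, 1]; group [0, 1] is provably contiguous
/// because 4 * 3 == 12.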
2774static FailureOr<StridedLayoutAttr>
2775computeCollapsedLayoutMap(MemRefType srcType,
2776 ArrayRef<ReassociationIndices> reassociation,
2777 bool strict = false) {
2778 int64_t srcOffset;
2779 SmallVector<int64_t> srcStrides;
2780 auto srcShape = srcType.getShape();
2781 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2782 return failure();
2783
2784 // The result stride of a reassociation group is the stride of the last entry
2785 // of the reassociation. (TODO: Should be the minimum stride in the
2786 // reassociation because strides are not necessarily sorted. E.g., when using
2787 // memref.transpose.) Dimensions of size 1 should be skipped, because their
2788 // strides are meaningless and could have any arbitrary value.
2789 SmallVector<int64_t> resultStrides;
2790 resultStrides.reserve(reassociation.size());
2791 for (const ReassociationIndices &reassoc : reassociation) {
2792 ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2793 while (srcShape[ref.back()] == 1 && ref.size() > 1)
2794 ref = ref.drop_back();
2795 if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2796 resultStrides.push_back(srcStrides[ref.back()]);
2797 } else {
2798 // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2799 // the corresponding stride may have to be skipped. (See above comment.)
2800 // Therefore, the result stride cannot be statically determined and must
2801 // be dynamic.
2802 resultStrides.push_back(ShapedType::kDynamic);
2803 }
2804 }
2805
2806 // Validate that each reassociation group is contiguous.
2807 unsigned resultStrideIndex = resultStrides.size() - 1;
2808 for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2809 auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2810 auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2811 for (int64_t idx : llvm::reverse(trailingReassocs)) {
2812 stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2813
2814 // Dimensions of size 1 should be skipped, because their strides are
2815 // meaningless and could have any arbitrary value.
2816 if (srcShape[idx - 1] == 1)
2817 continue;
2818
2819 // Both source and result stride must have the same static value. In that
2820 // case, we can be sure, that the dimensions are collapsible (because they
2821 // are contiguous).
2822 // If `strict = false` (default during op verification), we accept cases
2823 // where one or both strides are dynamic. This is best effort: We reject
2824 // ops where obviously non-contiguous dims are collapsed, but accept ops
2825 // where we cannot be sure statically. Such ops may fail at runtime. See
2826 // the op documentation for details.
2827 auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2828 if (strict && (stride.saturated || srcStride.saturated))
2829 return failure();
2830
2831 if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2832 return failure();
2833 }
2834 }
2835 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2836}
2837
2838bool CollapseShapeOp::isGuaranteedCollapsible(
2839 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2840 // MemRefs with identity layout are always collapsible.
2841 if (srcType.getLayout().isIdentity())
2842 return true;
2843
2844 return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2845 /*strict=*/true));
2846}
2847
2848MemRefType CollapseShapeOp::computeCollapsedType(
2849 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2850 SmallVector<int64_t> resultShape;
2851 resultShape.reserve(reassociation.size());
2852 for (const ReassociationIndices &group : reassociation) {
2853 auto groupSize = SaturatedInteger::wrap(1);
2854 for (int64_t srcDim : group)
2855 groupSize =
2856 groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2857 resultShape.push_back(groupSize.asInteger());
2858 }
2859
2860 if (srcType.getLayout().isIdentity()) {
2861 // If the source is contiguous (i.e., no layout map specified), so is the
2862 // result.
2863 MemRefLayoutAttrInterface layout;
2864 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2865 srcType.getMemorySpace());
2866 }
2867
2868 // Source may not be fully contiguous. Compute the layout map.
2869 // Note: Dimensions that are collapsed into a single dim are assumed to be
2870 // contiguous.
2871 FailureOr<StridedLayoutAttr> computedLayout =
2872 computeCollapsedLayoutMap(srcType, reassociation);
2873 assert(succeeded(computedLayout) &&
2874 "invalid source layout map or collapsing non-contiguous dims");
2875 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2876 srcType.getMemorySpace());
2877}
2878
2879void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2880 ArrayRef<ReassociationIndices> reassociation,
2881 ArrayRef<NamedAttribute> attrs) {
2882 auto srcType = llvm::cast<MemRefType>(src.getType());
2883 MemRefType resultType =
2884 CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2885 result.addAttribute(getReassociationAttrStrName(),
2886 getReassociationIndicesAttribute(b, reassociation));
2887 build(b, result, resultType, src, attrs);
2888}
2889
2890LogicalResult CollapseShapeOp::verify() {
2891 MemRefType srcType = getSrcType();
2892 MemRefType resultType = getResultType();
2893
2894 if (srcType.getRank() < resultType.getRank()) {
2895 auto r0 = srcType.getRank();
2896 auto r1 = resultType.getRank();
2897 return emitOpError("has source rank ")
2898 << r0 << " and result rank " << r1 << ". This is not a collapse ("
2899 << r0 << " < " << r1 << ").";
2900 }
2901
2902 // Verify result shape.
2903 if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2904 srcType.getShape(), getReassociationIndices(),
2905 /*allowMultipleDynamicDimsPerGroup=*/true)))
2906 return failure();
2907
2908 // Compute expected result type (including layout map).
2909 MemRefType expectedResultType;
2910 if (srcType.getLayout().isIdentity()) {
2911 // If the source is contiguous (i.e., no layout map specified), so is the
2912 // result.
2913 MemRefLayoutAttrInterface layout;
2914 expectedResultType =
2915 MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2916 srcType.getMemorySpace());
2917 } else {
2918 // Source may not be fully contiguous. Compute the layout map.
2919 // Note: Dimensions that are collapsed into a single dim are assumed to be
2920 // contiguous.
2921 FailureOr<StridedLayoutAttr> computedLayout =
2922 computeCollapsedLayoutMap(srcType, getReassociationIndices());
2923 if (failed(computedLayout))
2924 return emitOpError(
2925 "invalid source layout map or collapsing non-contiguous dims");
2926 expectedResultType =
2927 MemRefType::get(resultType.getShape(), srcType.getElementType(),
2928 *computedLayout, srcType.getMemorySpace());
2929 }
2930
2931 if (expectedResultType != resultType)
2932 return emitOpError("expected collapsed type to be ")
2933 << expectedResultType << " but found " << resultType;
2934
2935 return success();
2936}
2937
2938 class CollapseShapeOpMemRefCastFolder final
2939 : public OpRewritePattern<CollapseShapeOp> {
2940public:
2941 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2942
2943 LogicalResult matchAndRewrite(CollapseShapeOp op,
2944 PatternRewriter &rewriter) const override {
2945 auto cast = op.getOperand().getDefiningOp<CastOp>();
2946 if (!cast)
2947 return failure();
2948
2949 if (!CastOp::canFoldIntoConsumerOp(cast))
2950 return failure();
2951
2952 Type newResultType = CollapseShapeOp::computeCollapsedType(
2953 llvm::cast<MemRefType>(cast.getOperand().getType()),
2954 op.getReassociationIndices());
2955
2956 if (newResultType == op.getResultType()) {
2957 rewriter.modifyOpInPlace(
2958 op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2959 } else {
2960 Value newOp =
2961 CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
2962 op.getReassociationIndices());
2963 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2964 }
2965 return success();
2966 }
2967};
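// For illustration, the pattern above rewrites
// ```
// %0 = memref.cast %arg : memref<3x4xf32> to memref<?x?xf32>
// %1 = memref.collapse_shape %0 [[0, 1]] : memref<?x?xf32> into memref<?xf32>
// ```
// into a collapse of the cast's source plus a cast back to the old result type:
// ```
// %0 = memref.collapse_shape %arg [[0, 1]] : memref<3x4xf32> into memref<12xf32>
// %1 = memref.cast %0 : memref<12xf32> to memref<?xf32>
// ```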
2968
2969void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2970 MLIRContext *context) {
2971 results.add<
2972 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2973 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2974 memref::DimOp, MemRefType>,
2975 CollapseShapeOpMemRefCastFolder>(context);
2976}
2977
2978OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2979 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2980 adaptor.getOperands());
2981}
2982
2983OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2984 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2985 adaptor.getOperands());
2986}
2987
2988FailureOr<std::optional<SmallVector<Value>>>
2989CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
2990 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2991}
2992
2993//===----------------------------------------------------------------------===//
2994// ReshapeOp
2995//===----------------------------------------------------------------------===//
2996
2997void ReshapeOp::getAsmResultNames(
2998 function_ref<void(Value, StringRef)> setNameFn) {
2999 setNameFn(getResult(), "reshape");
3000}
3001
3002LogicalResult ReshapeOp::verify() {
3003 Type operandType = getSource().getType();
3004 Type resultType = getResult().getType();
3005
3006 Type operandElementType =
3007 llvm::cast<ShapedType>(operandType).getElementType();
3008 Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
3009 if (operandElementType != resultElementType)
3010 return emitOpError("element types of source and destination memref "
3011 "types should be the same");
3012
3013 if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
3014 if (!operandMemRefType.getLayout().isIdentity())
3015 return emitOpError("source memref type should have identity affine map");
3016
3017 int64_t shapeSize =
3018 llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
3019 auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
3020 if (resultMemRefType) {
3021 if (!resultMemRefType.getLayout().isIdentity())
3022 return emitOpError("result memref type should have identity affine map");
3023 if (shapeSize == ShapedType::kDynamic)
3024 return emitOpError("cannot use shape operand with dynamic length to "
3025 "reshape to statically-ranked memref type");
3026 if (shapeSize != resultMemRefType.getRank())
3027 return emitOpError(
3028 "length of shape operand differs from the result's memref rank");
3029 }
3030 return success();
3031}
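// For illustration, a well-formed reshape to a statically-ranked result:
// ```
// %dst = memref.reshape %src(%shape)
//          : (memref<4x5xf32>, memref<2xindex>) -> memref<?x?xf32>
// ```
// The shape operand has static length 2, matching the result rank, and both
// memref types use the identity layout.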
3032
3033FailureOr<std::optional<SmallVector<Value>>>
3034ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
3035 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3036}
3037
3038//===----------------------------------------------------------------------===//
3039// StoreOp
3040//===----------------------------------------------------------------------===//
3041
3042LogicalResult StoreOp::fold(FoldAdaptor adaptor,
3043 SmallVectorImpl<OpFoldResult> &results) {
3044 /// store(memrefcast) -> store
3045 return foldMemRefCast(*this, getValueToStore());
3046}
3047
3048TypedValue<MemRefType> StoreOp::getAccessedMemref() { return getMemref(); }
3049
3050std::optional<SmallVector<Value>>
3051StoreOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
3052 ValueRange newIndices) {
3053 rewriter.modifyOpInPlace(*this, [&]() {
3054 getMemrefMutable().assign(newMemref);
3055 getIndicesMutable().assign(newIndices);
3056 });
3057 return std::nullopt;
3058}
3059
3060FailureOr<std::optional<SmallVector<Value>>>
3061StoreOp::bubbleDownCasts(OpBuilder &builder) {
3063 ValueRange());
3064}
3065
3066//===----------------------------------------------------------------------===//
3067// SubViewOp
3068//===----------------------------------------------------------------------===//
3069
3070void SubViewOp::getAsmResultNames(
3071 function_ref<void(Value, StringRef)> setNameFn) {
3072 setNameFn(getResult(), "subview");
3073}
3074
3075/// A subview result type can be fully inferred from the source type and the
3076/// static representation of offsets, sizes and strides. Special sentinels
3077/// encode the dynamic case.
3078MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
3079 ArrayRef<int64_t> staticOffsets,
3080 ArrayRef<int64_t> staticSizes,
3081 ArrayRef<int64_t> staticStrides) {
3082 unsigned rank = sourceMemRefType.getRank();
3083 (void)rank;
3084 assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
3085 assert(staticSizes.size() == rank && "staticSizes length mismatch");
3086 assert(staticStrides.size() == rank && "staticStrides length mismatch");
3087
3088 // Extract source offset and strides.
3089 auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
3090
3091 // Compute target offset whose value is:
3092 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
3093 int64_t targetOffset = sourceOffset;
3094 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
3095 auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
3096 targetOffset = (SaturatedInteger::wrap(targetOffset) +
3097 SaturatedInteger::wrap(staticOffset) *
3098 SaturatedInteger::wrap(sourceStride))
3099 .asInteger();
3100 }
3101
3102 // Compute target stride whose value is:
3103 // `sourceStrides_i * staticStrides_i`.
3104 SmallVector<int64_t, 4> targetStrides;
3105 targetStrides.reserve(staticOffsets.size());
3106 for (auto it : llvm::zip(sourceStrides, staticStrides)) {
3107 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
3108 targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
3109 SaturatedInteger::wrap(staticStride))
3110 .asInteger());
3111 }
3112
3113 // The type is now known.
3114 return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
3115 StridedLayoutAttr::get(sourceMemRefType.getContext(),
3116 targetOffset, targetStrides),
3117 sourceMemRefType.getMemorySpace());
3118}
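// For illustration, a subview of memref<8x16xf32> (strides [16, 1], offset 0)
// with offsets [2, 4], sizes [3, 8], and strides [2, 1] yields
//   targetOffset  = 0 + 2 * 16 + 4 * 1 = 36
//   targetStrides = [16 * 2, 1 * 1]    = [32, 1]
// i.e. memref<3x8xf32, strided<[32, 1], offset: 36>>.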
3119
3120MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
3121 ArrayRef<OpFoldResult> offsets,
3122 ArrayRef<OpFoldResult> sizes,
3123 ArrayRef<OpFoldResult> strides) {
3124 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3125 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3126 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3127 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3128 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3129 if (!hasValidSizesOffsets(staticOffsets))
3130 return {};
3131 if (!hasValidSizesOffsets(staticSizes))
3132 return {};
3133 if (!hasValidStrides(staticStrides))
3134 return {};
3135 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
3136 staticSizes, staticStrides);
3137}
3138
3139MemRefType SubViewOp::inferRankReducedResultType(
3140 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
3141 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3142 ArrayRef<int64_t> strides) {
3143 MemRefType inferredType =
3144 inferResultType(sourceRankedTensorType, offsets, sizes, strides);
3145 assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
3146 "expected ");
3147 if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
3148 return inferredType;
3149
3150 // Compute which dimensions are dropped.
3151 std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
3152 computeRankReductionMask(inferredType.getShape(), resultShape);
3153 assert(dimsToProject.has_value() && "invalid rank reduction");
3154
3155 // Compute the layout and result type.
3156 auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
3157 SmallVector<int64_t> rankReducedStrides;
3158 rankReducedStrides.reserve(resultShape.size());
3159 for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
3160 if (!dimsToProject->contains(idx))
3161 rankReducedStrides.push_back(value);
3162 }
3163 return MemRefType::get(resultShape, inferredType.getElementType(),
3164 StridedLayoutAttr::get(inferredLayout.getContext(),
3165 inferredLayout.getOffset(),
3166 rankReducedStrides),
3167 inferredType.getMemorySpace());
3168}
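// For illustration, inferring with resultShape {6} for offsets [2, 0],
// sizes [1, 6], strides [1, 1] of a memref<8x6xf32> first gives
// memref<1x6xf32, strided<[6, 1], offset: 12>>; dropping the unit dimension
// (and its stride) gives memref<6xf32, strided<[1], offset: 12>>.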
3169
3170MemRefType SubViewOp::inferRankReducedResultType(
3171 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
3172 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
3173 ArrayRef<OpFoldResult> strides) {
3174 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3175 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3176 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3177 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3178 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3179 return SubViewOp::inferRankReducedResultType(
3180 resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
3181 staticStrides);
3182}
3183
3184// Build a SubViewOp with mixed static and dynamic entries and custom result
3185// type. If the type passed is nullptr, it is inferred.
3186void SubViewOp::build(OpBuilder &b, OperationState &result,
3187 MemRefType resultType, Value source,
3188 ArrayRef<OpFoldResult> offsets,
3189 ArrayRef<OpFoldResult> sizes,
3190 ArrayRef<OpFoldResult> strides,
3191 ArrayRef<NamedAttribute> attrs) {
3192 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3193 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3194 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3195 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3196 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3197 auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
3198 // Structuring implementation this way avoids duplication between builders.
3199 if (!resultType) {
3200 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
3201 staticSizes, staticStrides);
3202 }
3203 result.addAttributes(attrs);
3204 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
3205 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3206 b.getDenseI64ArrayAttr(staticSizes),
3207 b.getDenseI64ArrayAttr(staticStrides));
3208}
3209
3210// Build a SubViewOp with mixed static and dynamic entries and inferred result
3211// type.
3212void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3213 ArrayRef<OpFoldResult> offsets,
3214 ArrayRef<OpFoldResult> sizes,
3215 ArrayRef<OpFoldResult> strides,
3216 ArrayRef<NamedAttribute> attrs) {
3217 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
3218}
3219
3220// Build a SubViewOp with static entries and inferred result type.
3221void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3222 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3223 ArrayRef<int64_t> strides,
3224 ArrayRef<NamedAttribute> attrs) {
3225 SmallVector<OpFoldResult> offsetValues =
3226 llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
3227 return b.getI64IntegerAttr(v);
3228 });
3229 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
3230 sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
3231 SmallVector<OpFoldResult> strideValues =
3232 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
3233 return b.getI64IntegerAttr(v);
3234 });
3235 build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
3236}
3237
3238 // Build a SubViewOp with static entries and custom result type. If the
3239// type passed is nullptr, it is inferred.
3240void SubViewOp::build(OpBuilder &b, OperationState &result,
3241 MemRefType resultType, Value source,
3242 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3243 ArrayRef<int64_t> strides,
3244 ArrayRef<NamedAttribute> attrs) {
3245 SmallVector<OpFoldResult> offsetValues =
3246 llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
3247 return b.getI64IntegerAttr(v);
3248 });
3249 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
3250 sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
3251 SmallVector<OpFoldResult> strideValues =
3252 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
3253 return b.getI64IntegerAttr(v);
3254 });
3255 build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
3256 attrs);
3257}
3258
3259// Build a SubViewOp with dynamic entries and custom result type. If the type
3260// passed is nullptr, it is inferred.
3261void SubViewOp::build(OpBuilder &b, OperationState &result,
3262 MemRefType resultType, Value source, ValueRange offsets,
3263 ValueRange sizes, ValueRange strides,
3264 ArrayRef<NamedAttribute> attrs) {
3265 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3266 offsets, [](Value v) -> OpFoldResult { return v; });
3267 SmallVector<OpFoldResult> sizeValues =
3268 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
3269 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3270 strides, [](Value v) -> OpFoldResult { return v; });
3271 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
3272}
3273
3274// Build a SubViewOp with dynamic entries and inferred result type.
3275void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3276 ValueRange offsets, ValueRange sizes, ValueRange strides,
3277 ArrayRef<NamedAttribute> attrs) {
3278 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
3279}
3280
3281/// For ViewLikeOpInterface.
3282Value SubViewOp::getViewSource() { return getSource(); }
3283
3284/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
3285/// static value).
3286static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
3287 int64_t t1Offset, t2Offset;
3288 SmallVector<int64_t> t1Strides, t2Strides;
3289 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3290 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3291 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
3292}
3293
3294/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
3295/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
3296/// marked as dropped in `droppedDims`.
3297static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
3298 const llvm::SmallBitVector &droppedDims) {
3299 assert(size_t(t1.getRank()) == droppedDims.size() &&
3300 "incorrect number of bits");
3301 assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
3302 "incorrect number of dropped dims");
3303 int64_t t1Offset, t2Offset;
3304 SmallVector<int64_t> t1Strides, t2Strides;
3305 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3306 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3307 if (failed(res1) || failed(res2))
3308 return false;
3309 for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
3310 if (droppedDims[i])
3311 continue;
3312 if (t1Strides[i] != t2Strides[j])
3313 return false;
3314 ++j;
3315 }
3316 return true;
3317}
3318
3319 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
3320 SubViewOp op, Type expectedType) {
3321 auto memrefType = llvm::cast<ShapedType>(expectedType);
3322 switch (result) {
3323 case SliceVerificationResult::Success:
3324 return success();
3325 case SliceVerificationResult::RankTooLarge:
3326 return op->emitError("expected result rank to be smaller or equal to ")
3327 << "the source rank, but got " << op.getType();
3328 case SliceVerificationResult::SizeMismatch:
3329 return op->emitError("expected result type to be ")
3330 << expectedType
3331 << " or a rank-reduced version. (mismatch of result sizes), but got "
3332 << op.getType();
3333 case SliceVerificationResult::ElemTypeMismatch:
3334 return op->emitError("expected result element type to be ")
3335 << memrefType.getElementType() << ", but got " << op.getType();
3336 case SliceVerificationResult::MemSpaceMismatch:
3337 return op->emitError(
3338 "expected result and source memory spaces to match, but got ")
3339 << op.getType();
3340 case SliceVerificationResult::LayoutMismatch:
3341 return op->emitError("expected result type to be ")
3342 << expectedType
3343 << " or a rank-reduced version. (mismatch of result layout), but "
3344 "got "
3345 << op.getType();
3346 }
3347 llvm_unreachable("unexpected subview verification result");
3348}
3349
3350/// Verifier for SubViewOp.
3351LogicalResult SubViewOp::verify() {
3352 MemRefType baseType = getSourceType();
3353 MemRefType subViewType = getType();
3354 ArrayRef<int64_t> staticOffsets = getStaticOffsets();
3355 ArrayRef<int64_t> staticSizes = getStaticSizes();
3356 ArrayRef<int64_t> staticStrides = getStaticStrides();
3357
3358 // The base memref and the view memref should be in the same memory space.
3359 if (baseType.getMemorySpace() != subViewType.getMemorySpace())
3360 return emitError("different memory spaces specified for base memref "
3361 "type ")
3362 << baseType << " and subview memref type " << subViewType;
3363
3364 // Verify that the base memref type has a strided layout map.
3365 if (!baseType.isStrided())
3366 return emitError("base type ") << baseType << " is not strided";
3367
3368 // Compute the expected result type, assuming that there are no rank
3369 // reductions.
3370 MemRefType expectedType = SubViewOp::inferResultType(
3371 baseType, staticOffsets, staticSizes, staticStrides);
3372
3373 // Verify all properties of a shaped type: rank, element type and dimension
3374 // sizes. This takes into account potential rank reductions.
3375 auto shapedTypeVerification = isRankReducedType(
3376 /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
3377 if (shapedTypeVerification != SliceVerificationResult::Success)
3378 return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
3379
3380 // Make sure that the memory space did not change.
3381 if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3382 return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
3383 *this, expectedType);
3384
3385 // Verify the offset of the layout map.
3386 if (!haveCompatibleOffsets(expectedType, subViewType))
3387 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3388 *this, expectedType);
3389
3390 // Only the strides are left to verify. First, compute
3391 // the unused dimensions due to rank reductions. We have to look at sizes and
3392 // strides to decide which dimensions were dropped. This function also
3393 // partially verifies strides in case of rank reductions.
3394 auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3395 getMixedSizes());
3396 if (failed(unusedDims))
3397 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3398 *this, expectedType);
3399
3400 // Strides must match.
3401 if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3402 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3403 *this, expectedType);
3404
3405 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3406 // to the base memref.
3407 SliceBoundsVerificationResult boundsResult =
3408 verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
3409 staticStrides, /*generateErrorMessage=*/true);
3410 if (!boundsResult.isValid)
3411 return getOperation()->emitError(boundsResult.errorMessage);
3412
3413 return success();
3414}
3415
3417 return os << "range " << range.offset << ":" << range.size << ":"
3418 << range.stride;
3419}
3420
3421/// Return the list of Range (i.e. offset, size, stride). Each Range
3422/// entry contains either the dynamic value or a ConstantIndexOp constructed
3423/// with `b` at location `loc`.
3424SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3425 OpBuilder &b, Location loc) {
3426 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3427 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3428 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3429 SmallVector<Range, 8> res;
3430 unsigned rank = ranks[0];
3431 res.reserve(rank);
3432 for (unsigned idx = 0; idx < rank; ++idx) {
3433 Value offset =
3434 op.isDynamicOffset(idx)
3435 ? op.getDynamicOffset(idx)
3436 : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
3437 Value size =
3438 op.isDynamicSize(idx)
3439 ? op.getDynamicSize(idx)
3440 : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
3441 Value stride =
3442 op.isDynamicStride(idx)
3443 ? op.getDynamicStride(idx)
3444 : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
3445 res.emplace_back(Range{offset, size, stride});
3446 }
3447 return res;
3448}
3449
3450/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3451/// to deduce the result type for the given `sourceType`. Additionally, reduce
3452/// the rank of the inferred result type if `currentResultType` is lower rank
3453/// than `currentSourceType`. Use this signature if `sourceType` is updated
3454/// together with the result type. In this case, it is important to compute
3455/// the dropped dimensions using `currentSourceType` whose strides align with
3456/// `currentResultType`.
3457 static MemRefType getCanonicalSubViewResultType(
3458 MemRefType currentResultType, MemRefType currentSourceType,
3459 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3460 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3461 MemRefType nonRankReducedType = SubViewOp::inferResultType(
3462 sourceType, mixedOffsets, mixedSizes, mixedStrides);
3463 FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3464 currentSourceType, currentResultType, mixedSizes);
3465 if (failed(unusedDims))
3466 return nullptr;
3467
3468 auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3469 SmallVector<int64_t> shape, strides;
3470 unsigned numDimsAfterReduction =
3471 nonRankReducedType.getRank() - unusedDims->count();
3472 shape.reserve(numDimsAfterReduction);
3473 strides.reserve(numDimsAfterReduction);
3474 for (const auto &[idx, size, stride] :
3475 llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3476 nonRankReducedType.getShape(), layout.getStrides())) {
3477 if (unusedDims->test(idx))
3478 continue;
3479 shape.push_back(size);
3480 strides.push_back(stride);
3481 }
3482
3483 return MemRefType::get(shape, nonRankReducedType.getElementType(),
3484 StridedLayoutAttr::get(sourceType.getContext(),
3485 layout.getOffset(), strides),
3486 nonRankReducedType.getMemorySpace());
3487}
3488
3488
3489 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3490 OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3491 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3492 unsigned rank = memrefType.getRank();
3493 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3494 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3495 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3496 MemRefType targetType = SubViewOp::inferRankReducedResultType(
3497 targetShape, memrefType, offsets, sizes, strides);
3498 return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3499 sizes, strides);
3500}
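// For illustration, with %m : memref<1x16xf32> and targetShape {16}, this
// helper builds a rank-reducing subview over the full memref:
// ```
// %0 = memref.subview %m[0, 0] [1, 16] [1, 1]
//        : memref<1x16xf32> to memref<16xf32, strided<[1]>>
// ```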
3501
3502FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3503 Value value,
3504 ArrayRef<int64_t> desiredShape) {
3505 auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3506 assert(sourceMemrefType && "not a ranked memref type");
3507 auto sourceShape = sourceMemrefType.getShape();
3508 if (sourceShape.equals(desiredShape))
3509 return value;
3510 auto maybeRankReductionMask =
3511 mlir::computeRankReductionMask(sourceShape, desiredShape);
3512 if (!maybeRankReductionMask)
3513 return failure();
3514 return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3515}
3516
3517 /// Helper method to check if a `subview` operation is trivially a no-op. This
3518 /// is the case if all offsets are zero, all strides are one, and the source
3519 /// shape matches the size of the subview. In such cases, the subview can
3520 /// be folded into its source.
3521static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3522 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3523 return false;
3524
3525 auto mixedOffsets = subViewOp.getMixedOffsets();
3526 auto mixedSizes = subViewOp.getMixedSizes();
3527 auto mixedStrides = subViewOp.getMixedStrides();
3528
3529 // Check offsets are zero.
3530 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3531 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3532 return !intValue || intValue.value() != 0;
3533 }))
3534 return false;
3535
3536 // Check strides are one.
3537 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3538 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3539 return !intValue || intValue.value() != 1;
3540 }))
3541 return false;
3542
3543 // Check that all size values are static and match the (static) source shape.
3544 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3545 for (const auto &size : llvm::enumerate(mixedSizes)) {
3546 std::optional<int64_t> intValue = getConstantIntValue(size.value());
3547 if (!intValue || *intValue != sourceShape[size.index()])
3548 return false;
3549 }
3550 // All conditions met. The `SubViewOp` is foldable as a no-op.
3551 return true;
3552}
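// For illustration, this subview is trivially a no-op: zero offsets, unit
// strides, and sizes equal to the full source shape.
// ```
// %0 = memref.subview %arg[0, 0] [8, 16] [1, 1]
//        : memref<8x16xf32> to memref<8x16xf32>
// ```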
3553
3554namespace {
3555/// Pattern to rewrite a subview op with MemRefCast arguments.
3556/// This essentially pushes memref.cast past its consuming subview when
3557/// `canFoldIntoConsumerOp` is true.
3558///
3559/// Example:
3560/// ```
3561/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3562/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3563/// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3564/// ```
3565/// is rewritten into:
3566/// ```
3567/// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
3568/// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3569/// memref<3x4xf32, strided<[?, 1], offset: ?>>
3570/// ```
3571class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3572public:
3573 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3574
3575 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3576 PatternRewriter &rewriter) const override {
3577 // If any operand is constant, just return and let SubViewOpConstantFolder
3578 // kick in.
3579 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3580 return matchPattern(operand, matchConstantIndex());
3581 }))
3582 return failure();
3583
3584 auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3585 if (!castOp)
3586 return failure();
3587
3588 if (!CastOp::canFoldIntoConsumerOp(castOp))
3589 return failure();
3590
3591 // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3592 // the MemRefCastOp source operand type to infer the result type and the
3593 // current SubViewOp source operand type to compute the dropped dimensions
3594 // if the operation is rank-reducing.
3595 auto resultType = getCanonicalSubViewResultType(
3596 subViewOp.getType(), subViewOp.getSourceType(),
3597 llvm::cast<MemRefType>(castOp.getSource().getType()),
3598 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3599 subViewOp.getMixedStrides());
3600 if (!resultType)
3601 return failure();
3602
3603 Value newSubView = SubViewOp::create(
3604 rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
3605 subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3606 subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3607 subViewOp.getStaticStrides());
3608 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3609 newSubView);
3610 return success();
3611 }
3612};
3613
3614 /// Canonicalize subview ops that are no-ops: fold them into their source,
3615 /// inserting a cast when the types differ only in layout (e.g., `affine_map`).
3616class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3617public:
3618 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3619
3620 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3621 PatternRewriter &rewriter) const override {
3622 if (!isTrivialSubViewOp(subViewOp))
3623 return failure();
3624 if (subViewOp.getSourceType() == subViewOp.getType()) {
3625 rewriter.replaceOp(subViewOp, subViewOp.getSource());
3626 return success();
3627 }
3628 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3629 subViewOp.getSource());
3630 return success();
3631 }
3632};
3633} // namespace
3634
3635/// Return the canonical type of the result of a subview.
3636 struct SubViewReturnTypeCanonicalizer {
3637 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3638 ArrayRef<OpFoldResult> mixedSizes,
3639 ArrayRef<OpFoldResult> mixedStrides) {
3640 // Infer a memref type without taking into account any rank reductions.
3641 MemRefType resTy = SubViewOp::inferResultType(
3642 op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3643 if (!resTy)
3644 return {};
3645 MemRefType nonReducedType = resTy;
3646
3647 // Directly return the non-rank reduced type if there are no dropped dims.
3648 llvm::SmallBitVector droppedDims = op.getDroppedDims();
3649 if (droppedDims.none())
3650 return nonReducedType;
3651
3652 // Take the strides and offset from the non-rank reduced type.
3653 auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3654
3655 // Drop dims from shape and strides.
3656 SmallVector<int64_t> targetShape;
3657 SmallVector<int64_t> targetStrides;
3658 for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3659 if (droppedDims.test(i))
3660 continue;
3661 targetStrides.push_back(nonReducedStrides[i]);
3662 targetShape.push_back(nonReducedType.getDimSize(i));
3663 }
3664
3665 return MemRefType::get(targetShape, nonReducedType.getElementType(),
3666 StridedLayoutAttr::get(nonReducedType.getContext(),
3667 offset, targetStrides),
3668 nonReducedType.getMemorySpace());
3669 }
3670};
3671
3672/// A canonicalizer wrapper to replace SubViewOps.
3673 struct SubViewCanonicalizer {
3674 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3675 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3676 }
3677};
3678
3679void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3680 MLIRContext *context) {
3681 results
3682 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3683 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3684 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3685}
3686
3687OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3688 MemRefType sourceMemrefType = getSource().getType();
3689 MemRefType resultMemrefType = getResult().getType();
3690 auto resultLayout =
3691 dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3692
3693 if (resultMemrefType == sourceMemrefType &&
3694 resultMemrefType.hasStaticShape() &&
3695 (!resultLayout || resultLayout.hasStaticLayout())) {
3696 return getViewSource();
3697 }
3698
3699 // Fold subview(subview(x)), where both subviews have the same size and the
3700 // second subview's offsets are all zero. (I.e., the second subview is a
3701 // no-op.)
3702 if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3703 auto srcSizes = srcSubview.getMixedSizes();
3704 auto sizes = getMixedSizes();
3705 auto offsets = getMixedOffsets();
3706 bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3707 auto strides = getMixedStrides();
3708 bool allStridesOne = llvm::all_of(strides, isOneInteger);
3709 bool allSizesSame = llvm::equal(sizes, srcSizes);
3710 if (allOffsetsZero && allStridesOne && allSizesSame &&
3711 resultMemrefType == sourceMemrefType)
3712 return getViewSource();
3713 }
3714
3715 return {};
3716}
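// For illustration, the second subview below has zero offsets, unit strides,
// the same sizes, and the same result type as its source, so it folds to %0:
// ```
// %0 = memref.subview %src[%i, %j] [4, 4] [1, 1]
//        : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: ?>>
// %1 = memref.subview %0[0, 0] [4, 4] [1, 1]
//        : memref<4x4xf32, strided<[8, 1], offset: ?>>
//        to memref<4x4xf32, strided<[8, 1], offset: ?>>
// ```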
3717
3718FailureOr<std::optional<SmallVector<Value>>>
3719SubViewOp::bubbleDownCasts(OpBuilder &builder) {
3720 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3721}
3722
3723void SubViewOp::inferStridedMetadataRanges(
3724 ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
3725 SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
3726 auto isUninitialized =
3727 +[](IntegerValueRange range) { return range.isUninitialized(); };
3728
3729 // Bail early if the metadata of any operand is not yet ready:
3730 SmallVector<IntegerValueRange> offsetOperands =
3731 getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
3732 if (llvm::any_of(offsetOperands, isUninitialized))
3733 return;
3734
3735 SmallVector<IntegerValueRange> sizeOperands =
3736 getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
3737 if (llvm::any_of(sizeOperands, isUninitialized))
3738 return;
3739
3740 SmallVector<IntegerValueRange> stridesOperands =
3741 getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
3742 if (llvm::any_of(stridesOperands, isUninitialized))
3743 return;
3744
3745 StridedMetadataRange sourceRange =
3746 ranges[getSourceMutable().getOperandNumber()];
3747 if (sourceRange.isUninitialized())
3748 return;
3749
3750 ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
3751
3752 // Get the dropped dims.
3753 llvm::SmallBitVector droppedDims = getDroppedDims();
3754
3755 // Compute the new offset, strides and sizes.
3756 ConstantIntRanges offset = sourceRange.getOffsets()[0];
3757 SmallVector<ConstantIntRanges> strides, sizes;
3758
3759 for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
3760 bool dropped = droppedDims.test(i);
3761 // Compute the new offset.
3762 ConstantIntRanges off =
3763 intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
3764 offset = intrange::inferAdd({offset, off});
3765
3766 // Skip dropped dimensions.
3767 if (dropped)
3768 continue;
3769 // Multiply the strides.
3770 strides.push_back(
3771 intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
3772 // Get the sizes.
3773 sizes.push_back(sizeOperands[i].getValue());
3774 }
3775
3776 setMetadata(getResult(),
3777 StridedMetadataRange::getRanked(
3778 SmallVector<ConstantIntRanges>({std::move(offset)}),
3779 std::move(sizes), std::move(strides)));
3780}
3781
3782//===----------------------------------------------------------------------===//
3783// TransposeOp
3784//===----------------------------------------------------------------------===//
3785
3786void TransposeOp::getAsmResultNames(
3787 function_ref<void(Value, StringRef)> setNameFn) {
3788 setNameFn(getResult(), "transpose");
3789}
3790
3791/// Build a strided memref type by applying `permutationMap` to `memRefType`.
3792static MemRefType inferTransposeResultType(MemRefType memRefType,
3793 AffineMap permutationMap) {
3794 auto originalSizes = memRefType.getShape();
3795 auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3796 assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3797
3798 // Compute permuted sizes and strides.
3799 auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3800 auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3801
3802 return MemRefType::Builder(memRefType)
3803 .setShape(sizes)
3804 .setLayout(
3805 StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3806}
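// For illustration, applying the permutation (d0, d1) -> (d1, d0) to
// memref<4x8xf32> (sizes [4, 8], strides [8, 1]) permutes both lists, giving
// memref<8x4xf32, strided<[1, 8]>>.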
3807
3808void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3809 AffineMapAttr permutation,
3810 ArrayRef<NamedAttribute> attrs) {
3811 auto permutationMap = permutation.getValue();
3812 assert(permutationMap);
3813
3814 auto memRefType = llvm::cast<MemRefType>(in.getType());
3815 // Compute result type.
3816 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3817
3818 result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3819 build(b, result, resultType, in, attrs);
3820}
3821
3822// transpose $in $permutation attr-dict : type($in) `to` type(results)
3823void TransposeOp::print(OpAsmPrinter &p) {
3824 p << " " << getIn() << " " << getPermutation();
3825 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3826 p << " : " << getIn().getType() << " to " << getType();
3827}
3828
3829ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3830 OpAsmParser::UnresolvedOperand in;
3831 AffineMap permutation;
3832 MemRefType srcType, dstType;
3833 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3834 parser.parseOptionalAttrDict(result.attributes) ||
3835 parser.parseColonType(srcType) ||
3836 parser.resolveOperand(in, srcType, result.operands) ||
3837 parser.parseKeywordType("to", dstType) ||
3838 parser.addTypeToList(dstType, result.types))
3839 return failure();
3840
3841 result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3842 AffineMapAttr::get(permutation));
3843 return success();
3844}
3845
3846LogicalResult TransposeOp::verify() {
3847 if (!getPermutation().isPermutation())
3848 return emitOpError("expected a permutation map");
3849 if (getPermutation().getNumDims() != getIn().getType().getRank())
3850 return emitOpError("expected a permutation map of same rank as the input");
3851
3852 auto srcType = llvm::cast<MemRefType>(getIn().getType());
3853 auto resultType = llvm::cast<MemRefType>(getType());
3854 auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3855 .canonicalizeStridedLayout();
3856
3857 if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3858 return emitOpError("result type ")
3859 << resultType
3860 << " is not equivalent to the canonical transposed input type "
3861 << canonicalResultType;
3862 return success();
3863}
3864
3865OpFoldResult TransposeOp::fold(FoldAdaptor) {
3866 // First check for an identity permutation; we can fold it away if the
3867 // input and result types are already identical.
3868 if (getPermutation().isIdentity() && getType() == getIn().getType())
3869 return getIn();
3870 // Fold two consecutive memref.transpose Ops into one by composing their
3871 // permutation maps.
3872 if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3873 AffineMap composedPermutation =
3874 getPermutation().compose(otherTransposeOp.getPermutation());
3875 getInMutable().assign(otherTransposeOp.getIn());
3876 setPermutation(composedPermutation);
3877 return getResult();
3878 }
3879 return {};
3880}
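// For illustration, two consecutive transposes compose:
// ```
// %0 = memref.transpose %arg (d0, d1) -> (d1, d0)
//        : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
// %1 = memref.transpose %0 (d0, d1) -> (d1, d0)
//        : memref<8x4xf32, strided<[1, 8]>> to memref<4x8xf32>
// ```
// The fold rewrites %1 to read directly from %arg with the composed (identity)
// permutation, after which the identity check above removes it entirely.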
3881
3882FailureOr<std::optional<SmallVector<Value>>>
3883TransposeOp::bubbleDownCasts(OpBuilder &builder) {
3884 return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
3885}
3886
3887//===----------------------------------------------------------------------===//
3888// ViewOp
3889//===----------------------------------------------------------------------===//
3890
3891void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3892 setNameFn(getResult(), "view");
3893}
3894
3895LogicalResult ViewOp::verify() {
3896 auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3897 auto viewType = getType();
3898
3899 // The base memref should have identity layout map (or none).
3900 if (!baseType.getLayout().isIdentity())
3901 return emitError("unsupported map for base memref type ") << baseType;
3902
3903 // The result memref should have identity layout map (or none).
3904 if (!viewType.getLayout().isIdentity())
3905 return emitError("unsupported map for result memref type ") << viewType;
3906
3907 // The base memref and the view memref should be in the same memory space.
3908 if (baseType.getMemorySpace() != viewType.getMemorySpace())
3909 return emitError("different memory spaces specified for base memref "
3910 "type ")
3911 << baseType << " and view memref type " << viewType;
3912
3913 // Verify that we have the correct number of sizes for the result type.
3914 if (failed(verifyDynamicDimensionCount(getOperation(), viewType, getSizes())))
3915 return failure();
3916
3917 return success();
3918}
3919
3920Value ViewOp::getViewSource() { return getSource(); }
3921
3922OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3923 MemRefType sourceMemrefType = getSource().getType();
3924 MemRefType resultMemrefType = getResult().getType();
3925
3926 if (resultMemrefType == sourceMemrefType &&
3927 resultMemrefType.hasStaticShape() && isZeroInteger(getByteShift()))
3928 return getViewSource();
3929
3930 return {};
3931}
3932
3933SmallVector<OpFoldResult> ViewOp::getMixedSizes() {
3934 SmallVector<OpFoldResult> result;
3935 unsigned ctr = 0;
3936 Builder b(getContext());
3937 for (int64_t dim : getType().getShape()) {
3938 if (ShapedType::isDynamic(dim)) {
3939 result.push_back(getSizes()[ctr++]);
3940 } else {
3941 result.push_back(b.getIndexAttr(dim));
3942 }
3943 }
3944 return result;
3945}
3946
3947namespace {
3948/// Given a memref type and a range of values that defines its dynamic
3949/// dimension sizes, turn all dynamic sizes that have a constant value into
3950/// static dimension sizes.
3951static MemRefType
3952foldDynamicToStaticDimSizes(MemRefType type, ValueRange dynamicSizes,
3953 SmallVectorImpl<Value> &foldedDynamicSizes) {
3954 SmallVector<int64_t> staticShape(type.getShape());
3955 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
3956 "incorrect number of dynamic sizes");
3957
3958 // Compute new static and dynamic sizes.
3959 unsigned ctr = 0;
3960 for (auto [dim, dimSize] : llvm::enumerate(type.getShape())) {
3961 if (ShapedType::isStatic(dimSize))
3962 continue;
3963
3964 Value dynamicSize = dynamicSizes[ctr++];
3965 if (auto cst = getConstantIntValue(dynamicSize)) {
3966 // Dynamic size must be non-negative.
3967 if (cst.value() < 0) {
3968 foldedDynamicSizes.push_back(dynamicSize);
3969 continue;
3970 }
3971 staticShape[dim] = cst.value();
3972 } else {
3973 foldedDynamicSizes.push_back(dynamicSize);
3974 }
3975 }
3976
3977 return MemRefType::Builder(type).setShape(staticShape);
3978}
3979
3980/// Change the result type of a `memref.view` by making originally dynamic
3981/// dimensions static when their sizes come from `constant` ops.
3982/// Example:
3983/// ```
3984/// %c5 = arith.constant 5: index
3985/// %0 = memref.view %src[%offset][%c5] : memref<?xi8> to memref<?x4xf32>
3986/// ```
3987/// to
3988/// ```
3989/// %0 = memref.view %src[%offset][] : memref<?xi8> to memref<5x4xf32>
3990/// ```
3991struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3992 using Base::Base;
3993
3994 LogicalResult matchAndRewrite(ViewOp viewOp,
3995 PatternRewriter &rewriter) const override {
3996 SmallVector<Value> foldedDynamicSizes;
3997 MemRefType resultType = viewOp.getType();
3998 MemRefType foldedMemRefType = foldDynamicToStaticDimSizes(
3999 resultType, viewOp.getSizes(), foldedDynamicSizes);
4000
4001 // Stop here if no dynamic size was promoted to static.
4002 if (foldedMemRefType == resultType)
4003 return failure();
4004
4005 // Create new ViewOp.
4006 auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), foldedMemRefType,
4007 viewOp.getSource(), viewOp.getByteShift(),
4008 foldedDynamicSizes);
4009 // Insert a cast so we have the same type as the old memref type.
4010 rewriter.replaceOpWithNewOp<CastOp>(viewOp, resultType, newViewOp);
4011 return success();
4012 }
4013};
4014
4015/// view(memref.cast(%source)) -> view(%source).
4016struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
4017 using Base::Base;
4018
4019 LogicalResult matchAndRewrite(ViewOp viewOp,
4020 PatternRewriter &rewriter) const override {
4021 auto memrefCastOp = viewOp.getSource().getDefiningOp<CastOp>();
4022 if (!memrefCastOp)
4023 return failure();
4024
4025 rewriter.replaceOpWithNewOp<ViewOp>(
4026 viewOp, viewOp.getType(), memrefCastOp.getSource(),
4027 viewOp.getByteShift(), viewOp.getSizes());
4028 return success();
4029 }
4030};
4031} // namespace
4032
4033void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
4034 MLIRContext *context) {
4035 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
4036}
4037
4038FailureOr<std::optional<SmallVector<Value>>>
4039ViewOp::bubbleDownCasts(OpBuilder &builder) {
4040 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
4041}
4042
4043//===----------------------------------------------------------------------===//
4044// AtomicRMWOp
4045//===----------------------------------------------------------------------===//
4046
4047LogicalResult AtomicRMWOp::verify() {
4048 switch (getKind()) {
4049 case arith::AtomicRMWKind::addf:
4050 case arith::AtomicRMWKind::maximumf:
4051 case arith::AtomicRMWKind::minimumf:
4052 case arith::AtomicRMWKind::mulf:
4053 if (!llvm::isa<FloatType>(getValue().getType()))
4054 return emitOpError() << "with kind '"
4055 << arith::stringifyAtomicRMWKind(getKind())
4056 << "' expects a floating-point type";
4057 break;
4058 case arith::AtomicRMWKind::addi:
4059 case arith::AtomicRMWKind::maxs:
4060 case arith::AtomicRMWKind::maxu:
4061 case arith::AtomicRMWKind::mins:
4062 case arith::AtomicRMWKind::minu:
4063 case arith::AtomicRMWKind::muli:
4064 case arith::AtomicRMWKind::ori:
4065 case arith::AtomicRMWKind::xori:
4066 case arith::AtomicRMWKind::andi:
4067 if (!llvm::isa<IntegerType>(getValue().getType()))
4068 return emitOpError() << "with kind '"
4069 << arith::stringifyAtomicRMWKind(getKind())
4070 << "' expects an integer type";
4071 break;
4072 default:
4073 break;
4074 }
4075 return success();
4076}
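// For illustration, the verifier rejects a floating-point kind applied to an
// integer value:
// ```
// // error: with kind 'addf' expects a floating-point type
// %0 = memref.atomic_rmw addf %val, %mem[%i] : (i32, memref<8xi32>) -> i32
// ```
// Using addi on the same operands verifies.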
4077
4078OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
4079 /// atomicrmw(memrefcast) -> atomicrmw
4080 if (succeeded(foldMemRefCast(*this, getValue())))
4081 return getResult();
4082 return OpFoldResult();
4083}
4084
4085FailureOr<std::optional<SmallVector<Value>>>
4086AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
4088 getResult());
4089}
4090
4091TypedValue<MemRefType> AtomicRMWOp::getAccessedMemref() { return getMemref(); }
4092
4093std::optional<SmallVector<Value>>
4094AtomicRMWOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
4095 ValueRange newIndices) {
4096 rewriter.modifyOpInPlace(*this, [&]() {
4097 getMemrefMutable().assign(newMemref);
4098 getIndicesMutable().assign(newIndices);
4099 });
4100 return std::nullopt;
4101}
4102
4103//===----------------------------------------------------------------------===//
4104// TableGen'd op method definitions
4105//===----------------------------------------------------------------------===//
4106
4107#define GET_OP_CLASSES
4108#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool hasSideEffects(Operation *op)
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition IRAffine.cpp:59
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
auto load
static LogicalResult foldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
Definition MemRefOps.cpp:98
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMaskByStrides(MemRefType originalType, MemRefType reducedType, ArrayRef< int64_t > originalStrides, ArrayRef< int64_t > candidateStrides, llvm::SmallBitVector unusedDims)
Returns the set of source dimensions that are dropped in a rank reduction.
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMaskByPosition(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Returns the set of source dimensions that are dropped in a rank reduction.
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:255
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
IndexType getIndexType()
Definition Builders.cpp:55
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it, emitting errors, etc.
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to the list on success.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter, and an optional required operand count.
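A minimal sketch of how these parser hooks compose inside a custom parse method, for a hypothetical single-operand op (the op name and syntax are illustrative, not part of this file):

  // custom assembly: %operand {attr-dict} : type
  static ParseResult parseMyOp(OpAsmParser &parser, OperationState &result) {
    OpAsmParser::UnresolvedOperand operand;
    Type type;
    if (parser.parseOperand(operand) ||
        parser.parseOptionalAttrDict(result.attributes) ||
        parser.parseColonType(type) ||
        parser.resolveOperand(operand, type, result.operands))
      return failure();
    result.addTypes(type);
    return success();
  }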
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom print() method.
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the operation using the map that is provided (leaving them alone if no entry is present).
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold it.
Definition Builders.h:528
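A hedged usage sketch: createOrFold builds the op at the insertion point and immediately runs its folder, so the returned Value may be a pre-existing SSA value or a materialized constant rather than a fresh op result (src and loc are assumed):

  // If dim 1 of %src is static, this folds to a constant index instead of
  // materializing a memref.dim op.
  Value size = builder.createOrFold<memref::DimOp>(loc, src, /*index=*/1);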
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent insertions to go right after it.
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
A trait of region holding operations that define a new scope for automatic allocations, i.e., allocations that are freed when control is transferred back from the operation's region.
This trait indicates that the memory effects of an operation include the effects of operations nested within its regions.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation, or nullptr if this is a top-level operation.
Definition Operation.h:252
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers that may be listening.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:248
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current IR being matched, providing a way to keep track of any mutations made.
Type-safe wrapper around a void* for passing properties, including the properties structs of operations.
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Region.h:98
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StridedMetadataRange getRanked(SmallVectorImpl< ConstantIntRanges > &&offsets, SmallVectorImpl< ConstantIntRanges > &&sizes, SmallVectorImpl< ConstantIntRanges > &&strides)
Returns a ranked strided metadata range.
ArrayRef< ConstantIntRanges > getStrides() const
Get the strides ranges.
bool isUninitialized() const
Returns whether the metadata is uninitialized.
ArrayRef< ConstantIntRanges > getOffsets() const
Get the offsets range.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of, or including, 'from' with the 'OpTrait::SymbolTable' trait.
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIndex() const
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
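These three results steer Operation::walk. A small sketch (the callback body is illustrative):

  WalkResult status = op->walk([](Operation *nested) {
    if (isa<memref::AllocaOp>(nested))
      return WalkResult::interrupt(); // stop the whole traversal
    if (nested->hasTrait<OpTrait::IsIsolatedFromAbove>())
      return WalkResult::skip();      // do not recurse into this op
    return WalkResult::advance();     // continue normally
  });
  if (status.wasInterrupted()) {
    // an alloca was found somewhere below `op`
  }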
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interface.
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble down, in place, a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition MemRefOps.cpp:62
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition MemRefOps.cpp:70
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
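A representative fold hook built on this utility; MyLoadLikeOp is a hypothetical single-result op whose operands may be produced by memref.cast:

  // someop(memref.cast(%src)) -> someop(%src)
  OpFoldResult MyLoadLikeOp::fold(FoldAdaptor adaptor) {
    return succeeded(foldMemRefCast(*this)) ? OpFoldResult(getResult())
                                            : OpFoldResult();
  }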
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:79
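A usage sketch, assuming src is a value of type memref<4x?xf32>:

  // sizes[0] is the IntegerAttr 4; sizes[1] is the SSA value produced by a
  // memref.dim op for the dynamic dimension.
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(builder, loc, src);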
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes to reduce the rank of memref to that of targetShape.
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators, which equals the absolute value of the determinant of the matrix formed by those generators.
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but with all elements for which ShapedType::isDynamic is true replaced by the corresponding entries of dynamicValues.
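A small sketch of the recomposition direction (staticValues and dynamicValues are assumed inputs):

  // staticValues = {4, ShapedType::kDynamic}, dynamicValues = {%d}
  // -> mixed = {IntegerAttr 4, Value %d}
  SmallVector<OpFoldResult> mixed =
      getMixedValues(staticValues, dynamicValues, builder.getContext());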
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition Matchers.h:527
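A usage sketch pairing the binder with matchPattern:

  APInt cst;
  if (matchPattern(value, m_ConstantInt(&cst)) && cst.isZero()) {
    // `value` is a compile-time-known zero.
  }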
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shaped type.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size and stride). Each Range entry contains either the dynamic value or a ConstantIndexOp constructed with b at location loc.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult.
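A sketch of the decomposition direction, the inverse of getMixedValues above (mixedSizes is an assumed ArrayRef<OpFoldResult>):

  SmallVector<Value> dynamicSizes;
  SmallVector<int64_t> staticSizes;
  dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);
  // Attribute entries become integers in staticSizes; Value entries are
  // recorded there as ShapedType::kDynamic and appended to dynamicSizes.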
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:26
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn
Callback function type for setting the strided metadata of a value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries erased, return the set of indices that specifies which of the entries of originalShape are dropped to obtain reducedShape.
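A worked example:

  // originalShape = [1, 4, 1, 8], reducedShape = [4, 8]
  // -> mask = {0, 2}: the unit dimensions that were dropped.
  std::optional<llvm::SmallDenseSet<unsigned>> mask =
      computeRankReductionMask({1, 4, 1, 8}, {4, 8});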
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions with static size 1.
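A usage sketch in a verifier (originalType and reducedType are assumed shaped types):

  // memref<1x4x1x8xf32> -> memref<4x8xf32> drops only unit dims, so this
  // returns SliceVerificationResult::Success.
  SliceVerificationResult res = isRankReducedType(originalType, reducedType);
  if (res != SliceVerificationResult::Success)
    return op->emitOpError("expected a rank-reduced type");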
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
function_ref< IntegerValueRange(Value)> GetIntRangeFn
Helper callback type to get the integer range of a value.
Move allocations into an allocation scope, if it is legal to move them (e.g. their operands are defined outside of the AllocaScopeOp).
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocation.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
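A minimal sketch of the idiom; EraseIdentityCast is a hypothetical pattern, not one defined in this file:

  struct EraseIdentityCast : public OpRewritePattern<memref::CastOp> {
    using OpRewritePattern<memref::CastOp>::OpRewritePattern;

    LogicalResult matchAndRewrite(memref::CastOp castOp,
                                  PatternRewriter &rewriter) const override {
      if (castOp.getSource().getType() != castOp.getType())
        return rewriter.notifyMatchFailure(castOp, "not an identity cast");
      rewriter.replaceOp(castOp, castOp.getSource());
      return success();
    }
  };

Such a pattern would be registered through RewritePatternSet::add<EraseIdentityCast>(context).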
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.