MLIR 23.0.0git
MemRefOps.cpp
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/AffineMap.h"
15#include "mlir/IR/Builders.h"
17#include "mlir/IR/Matchers.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
27#include "llvm/ADT/SmallVectorExtras.h"
28
29using namespace mlir;
30using namespace mlir::memref;
31
32/// Materialize a single constant operation from a given attribute value with
33/// the desired resultant type.
34Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
35 Attribute value, Type type,
36 Location loc) {
37 return arith::ConstantOp::materialize(builder, value, type, loc);
38}
39
40//===----------------------------------------------------------------------===//
41// Common canonicalization pattern support logic
42//===----------------------------------------------------------------------===//
43
44/// This is a common class used for patterns of the form
45/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
46/// into the root operation directly.
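/// For illustration (hypothetical IR), a consumer such as `memref.load` fed by
/// a ranked cast
///
/// ```mlir
/// %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
/// %1 = memref.load %0[%i] : memref<?xf32>
/// ```
///
/// is updated in place to load from `%arg` directly, dropping the cast from
/// its operand list.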
47LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
48 bool folded = false;
49 for (OpOperand &operand : op->getOpOperands()) {
50 auto cast = operand.get().getDefiningOp<CastOp>();
51 if (cast && operand.get() != inner &&
52 !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
53 operand.set(cast.getOperand());
54 folded = true;
55 }
56 }
57 return success(folded);
58}
59
60/// Return an unranked/ranked tensor type for the given unranked/ranked memref
61/// type.
62Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
63 if (auto memref = llvm::dyn_cast<MemRefType>(type))
64 return RankedTensorType::get(memref.getShape(), memref.getElementType());
65 if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
66 return UnrankedTensorType::get(memref.getElementType());
67 return NoneType::get(type.getContext());
68}
69
70OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
71 int64_t dim) {
72 auto memrefType = llvm::cast<MemRefType>(value.getType());
73 if (memrefType.isDynamicDim(dim))
74 return builder.createOrFold<memref::DimOp>(loc, value, dim);
75
76 return builder.getIndexAttr(memrefType.getDimSize(dim));
77}
78
79SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
80 Location loc, Value value) {
81 auto memrefType = llvm::cast<MemRefType>(value.getType());
82 SmallVector<OpFoldResult> result;
83 for (int64_t i = 0; i < memrefType.getRank(); ++i)
84 result.push_back(getMixedSize(builder, loc, value, i));
85 return result;
86}
87
88//===----------------------------------------------------------------------===//
89// Utility functions for propagating static information
90//===----------------------------------------------------------------------===//
91
92/// Helper function that sets values[i] to constValues[i] if the latter is a
93/// static value, i.e., not ShapedType::kDynamic.
94///
95/// If constValues[i] is dynamic, tries to extract a constant value from
96/// values[i] to allow for additional folding opportunities. Also converts all
97/// existing attributes to index attributes. (They may be i64 attributes.)
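/// For illustration (hypothetical values): with `constValues = [8, kDynamic]`
/// and `values = [%v, 3 : i64]`, the first entry becomes the index attribute 8
/// (the static constant wins) and the second becomes the index attribute 3
/// (folded from the existing i64 attribute).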
98static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
99 ArrayRef<int64_t> constValues) {
100 assert(constValues.size() == values.size() &&
101 "incorrect number of const values");
102 for (auto [i, cstVal] : llvm::enumerate(constValues)) {
103 Builder builder(values[i].getContext());
104 if (ShapedType::isStatic(cstVal)) {
105 // Constant value is known, use it directly.
106 values[i] = builder.getIndexAttr(cstVal);
107 continue;
108 }
109 if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
110 // Try to extract a constant or convert an existing to index.
111 values[i] = builder.getIndexAttr(*cst);
112 }
113 }
114}
115
116/// Helper function to retrieve a lossless memory-space cast, and the
117/// corresponding new result memref type.
118static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
119getMemorySpaceCastInfo(PtrLikeTypeInterface resultTy, Value src) {
120 MemorySpaceCastOpInterface castOp =
121 MemorySpaceCastOpInterface::getIfPromotableCast(src);
122
123 // Bail if the cast is not lossless.
124 if (!castOp)
125 return {};
126
127 // Transform the source and target type of `castOp` to have the same metadata
128 // as `resultTy`. Bail if not possible.
129 FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
130 castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
131 if (failed(srcTy))
132 return {};
133
134 FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
135 castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
136 if (failed(tgtTy))
137 return {};
138
139 // Check if this is a valid memory-space cast.
140 if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
141 return {};
142
143 return std::make_tuple(castOp, *tgtTy, *srcTy);
144}
145
146/// Implementation of `bubbleDownCasts` method for memref operations that
147/// return a single memref result.
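/// As a rough sketch (hypothetical types), with `memref.cast` as the
/// pass-through op:
///
/// ```mlir
/// %0 = memref.memory_space_cast %m : memref<8xf32, 1> to memref<8xf32>
/// %1 = memref.cast %0 : memref<8xf32> to memref<?xf32>
/// ```
///
/// the cast is rebuilt on `%m` in memory space 1, and a memory-space cast back
/// to the original space is inserted to replace `%1`.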
148template <typename ConcreteOpTy>
149static FailureOr<std::optional<SmallVector<Value>>>
150bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
151 OpOperand &src) {
152 auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
153 // Bail if we cannot cast.
154 if (!castOp)
155 return failure();
156
157 // Create the new operands.
158 SmallVector<Value> operands;
159 llvm::append_range(operands, op->getOperands());
160 operands[src.getOperandNumber()] = castOp.getSourcePtr();
161
162 // Create the new op and results.
163 auto newOp = ConcreteOpTy::create(
164 builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
165 llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
166
167 // Insert a memory-space cast to the original memory space of the op.
168 MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
169 builder, tgtTy,
170 cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
171 return std::optional<SmallVector<Value>>(
172 SmallVector<Value>({result.getTargetPtr()}));
173}
174
175//===----------------------------------------------------------------------===//
176// AllocOp / AllocaOp
177//===----------------------------------------------------------------------===//
178
179void AllocOp::getAsmResultNames(
180 function_ref<void(Value, StringRef)> setNameFn) {
181 setNameFn(getResult(), "alloc");
182}
183
184void AllocaOp::getAsmResultNames(
185 function_ref<void(Value, StringRef)> setNameFn) {
186 setNameFn(getResult(), "alloca");
187}
188
189template <typename AllocLikeOp>
190static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
191 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
192 "applies to only alloc or alloca");
193 auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
194 if (!memRefType)
195 return op.emitOpError("result must be a memref");
196
197 if (failed(verifyDynamicDimensionCount(op, memRefType, op.getDynamicSizes())))
198 return failure();
199
200 unsigned numSymbols = 0;
201 if (!memRefType.getLayout().isIdentity())
202 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
203 if (op.getSymbolOperands().size() != numSymbols)
204 return op.emitOpError("symbol operand count does not equal memref symbol "
205 "count: expected ")
206 << numSymbols << ", got " << op.getSymbolOperands().size();
207
208 return success();
209}
210
211LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
212
213LogicalResult AllocaOp::verify() {
214 // An alloca op needs to have an ancestor with an allocation scope trait.
215 if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
216 return emitOpError(
217 "requires an ancestor op with AutomaticAllocationScope trait");
218
219 return verifyAllocLikeOp(*this);
220}
221
222namespace {
223/// Fold constant dimensions into an alloc like operation.
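/// For illustration (hypothetical IR), with `%c8 = arith.constant 8 : index`:
///
/// ```mlir
/// %0 = memref.alloc(%c8, %n) : memref<?x?xf32>
/// ```
///
/// becomes
///
/// ```mlir
/// %1 = memref.alloc(%n) : memref<8x?xf32>
/// %0 = memref.cast %1 : memref<8x?xf32> to memref<?x?xf32>
/// ```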
224template <typename AllocLikeOp>
225struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
226 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
227
228 LogicalResult matchAndRewrite(AllocLikeOp alloc,
229 PatternRewriter &rewriter) const override {
230 // Check to see if any dimension operands are constants. If so, we can
231 // substitute and drop them.
232 if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
233 APInt constSizeArg;
234 if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
235 return false;
236 return constSizeArg.isNonNegative();
237 }))
238 return failure();
239
240 auto memrefType = alloc.getType();
241
242 // Ok, we have one or more constant operands. Collect the non-constant ones
243 // and keep track of the resultant memref type to build.
244 SmallVector<int64_t, 4> newShapeConstants;
245 newShapeConstants.reserve(memrefType.getRank());
246 SmallVector<Value, 4> dynamicSizes;
247
248 unsigned dynamicDimPos = 0;
249 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
250 int64_t dimSize = memrefType.getDimSize(dim);
251 // If this is already a static dimension, keep it.
252 if (ShapedType::isStatic(dimSize)) {
253 newShapeConstants.push_back(dimSize);
254 continue;
255 }
256 auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
257 APInt constSizeArg;
258 if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
259 constSizeArg.isNonNegative()) {
260 // Dynamic shape dimension will be folded.
261 newShapeConstants.push_back(constSizeArg.getZExtValue());
262 } else {
263 // Dynamic shape dimension not folded; copy dynamicSize from old memref.
264 newShapeConstants.push_back(ShapedType::kDynamic);
265 dynamicSizes.push_back(dynamicSize);
266 }
267 dynamicDimPos++;
268 }
269
270 // Create new memref type (which will have fewer dynamic dimensions).
271 MemRefType newMemRefType =
272 MemRefType::Builder(memrefType).setShape(newShapeConstants);
273 assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
274
275 // Create and insert the alloc op for the new memref.
276 auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
277 dynamicSizes, alloc.getSymbolOperands(),
278 alloc.getAlignmentAttr());
279 // Insert a cast so we have the same type as the old alloc.
280 rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
281 return success();
282 }
283};
284
285/// Fold alloc operations with no users or only store and dealloc uses.
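/// For illustration (hypothetical IR), the following three ops are all erased,
/// since the allocation is only written to and deallocated:
///
/// ```mlir
/// %0 = memref.alloc() : memref<4xf32>
/// memref.store %v, %0[%i] : memref<4xf32>
/// memref.dealloc %0 : memref<4xf32>
/// ```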
286template <typename T>
287struct SimplifyDeadAlloc : public OpRewritePattern<T> {
288 using OpRewritePattern<T>::OpRewritePattern;
289
290 LogicalResult matchAndRewrite(T alloc,
291 PatternRewriter &rewriter) const override {
292 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
293 if (auto storeOp = dyn_cast<StoreOp>(op))
294 return storeOp.getValue() == alloc;
295 return !isa<DeallocOp>(op);
296 }))
297 return failure();
298
299 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
300 rewriter.eraseOp(user);
301
302 rewriter.eraseOp(alloc);
303 return success();
304 }
305};
306} // namespace
307
308void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
309 MLIRContext *context) {
310 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
311}
312
313void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
314 MLIRContext *context) {
315 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
316 context);
317}
318
319//===----------------------------------------------------------------------===//
320// ReallocOp
321//===----------------------------------------------------------------------===//
322
323LogicalResult ReallocOp::verify() {
324 auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
325 MemRefType resultType = getType();
326
327 // The source memref should have identity layout (or none).
328 if (!sourceType.getLayout().isIdentity())
329 return emitError("unsupported layout for source memref type ")
330 << sourceType;
331
332 // The result memref should have identity layout (or none).
333 if (!resultType.getLayout().isIdentity())
334 return emitError("unsupported layout for result memref type ")
335 << resultType;
336
337 // The source memref and the result memref should be in the same memory space.
338 if (sourceType.getMemorySpace() != resultType.getMemorySpace())
339 return emitError("different memory spaces specified for source memref "
340 "type ")
341 << sourceType << " and result memref type " << resultType;
342
343 // The source memref and the result memref should have the same element type.
344 if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
345 "result")))
346 return failure();
347
348 // Verify that we have the dynamic dimension operand when it is needed.
349 if (resultType.getNumDynamicDims() && !getDynamicResultSize())
350 return emitError("missing dimension operand for result type ")
351 << resultType;
352 if (!resultType.getNumDynamicDims() && getDynamicResultSize())
353 return emitError("unnecessary dimension operand for result type ")
354 << resultType;
355
356 return success();
357}
358
359void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
360 MLIRContext *context) {
361 results.add<SimplifyDeadAlloc<ReallocOp>>(context);
362}
363
364//===----------------------------------------------------------------------===//
365// AllocaScopeOp
366//===----------------------------------------------------------------------===//
367
368void AllocaScopeOp::print(OpAsmPrinter &p) {
369 bool printBlockTerminators = false;
370
371 p << ' ';
372 if (!getResults().empty()) {
373 p << " -> (" << getResultTypes() << ")";
374 printBlockTerminators = true;
375 }
376 p << ' ';
377 p.printRegion(getBodyRegion(),
378 /*printEntryBlockArgs=*/false,
379 /*printBlockTerminators=*/printBlockTerminators);
380 p.printOptionalAttrDict((*this)->getAttrs());
381}
382
383ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
384 // Create a region for the body.
385 result.regions.reserve(1);
386 Region *bodyRegion = result.addRegion();
387
388 // Parse optional results type list.
389 if (parser.parseOptionalArrowTypeList(result.types))
390 return failure();
391
392 // Parse the body region.
393 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
394 return failure();
395 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
396 result.location);
397
398 // Parse the optional attribute list.
399 if (parser.parseOptionalAttrDict(result.attributes))
400 return failure();
401
402 return success();
403}
404
405void AllocaScopeOp::getSuccessorRegions(
406 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
407 if (!point.isParent()) {
408 regions.push_back(RegionSuccessor::parent());
409 return;
410 }
411
412 regions.push_back(RegionSuccessor(&getBodyRegion()));
413}
414
415ValueRange AllocaScopeOp::getSuccessorInputs(RegionSuccessor successor) {
416 return successor.isParent() ? ValueRange(getResults()) : ValueRange();
417}
418
419/// Given an operation, return whether this op is guaranteed to
420/// allocate an AutomaticAllocationScopeResource
421static bool isGuaranteedAutomaticAllocation(Operation *op) {
422 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
423 if (!interface)
424 return false;
425 for (auto res : op->getResults()) {
426 if (auto effect =
427 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
428 if (isa<SideEffects::AutomaticAllocationScopeResource>(
429 effect->getResource()))
430 return true;
431 }
432 }
433 return false;
434}
435
436/// Given an operation, return whether this op itself could
437/// allocate an AutomaticAllocationScopeResource. Note that
438/// this will not check whether an operation contained within
439/// the op can allocate.
440static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
441 // This op itself doesn't create a stack allocation,
442 // the inner allocation should be handled separately.
443 if (op->hasTrait<OpTrait::AutomaticAllocationScope>())
444 return false;
445 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
446 if (!interface)
447 return true;
448 for (auto res : op->getResults()) {
449 if (auto effect =
450 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
451 if (isa<SideEffects::AutomaticAllocationScopeResource>(
452 effect->getResource()))
453 return true;
454 }
455 }
456 return false;
457}
458
459/// Return whether this op is the last non terminating op
460/// in a region. That is to say, it is in a one-block region
461/// and is only followed by a terminator. This prevents
462/// extending the lifetime of allocations.
463static bool lastNonTerminatorInRegion(Operation *op) {
464 return op->getBlock()->mightHaveTerminator() &&
465 op->getNextNode() == op->getBlock()->getTerminator() &&
466 op->getParentRegion()->hasOneBlock();
467}
468
469/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
470/// or it contains no allocation.
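/// For illustration (hypothetical IR), a scope that performs no allocation is
/// inlined into its parent:
///
/// ```mlir
/// %r = memref.alloca_scope -> (f32) {
///   %v = arith.constant 0.0 : f32
///   memref.alloca_scope.return %v : f32
/// }
/// ```
///
/// is replaced by its body, with `%r` taking the yielded value `%v`.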
471struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
472 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
473
474 LogicalResult matchAndRewrite(AllocaScopeOp op,
475 PatternRewriter &rewriter) const override {
476 bool hasPotentialAlloca =
477 op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
478 if (alloc == op)
479 return WalkResult::advance();
480 if (isOpItselfPotentialAutomaticAllocation(alloc))
481 return WalkResult::interrupt();
482 if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
483 return WalkResult::skip();
484 return WalkResult::advance();
485 }).wasInterrupted();
486
487 // If this contains no potential allocation, it is always legal to
488 // inline. Otherwise, consider two conditions:
489 if (hasPotentialAlloca) {
490 // If the parent isn't an allocation scope, or we are not the last
491 // non-terminator op in the parent, we will extend the lifetime.
492 if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
493 return failure();
494 if (!lastNonTerminatorInRegion(op))
495 return failure();
496 }
497
498 Block *block = &op.getRegion().front();
499 Operation *terminator = block->getTerminator();
500 ValueRange results = terminator->getOperands();
501 rewriter.inlineBlockBefore(block, op);
502 rewriter.replaceOp(op, results);
503 rewriter.eraseOp(terminator);
504 return success();
505 }
506};
507
508/// Move allocations into an allocation scope, if it is legal to
509/// move them (e.g. their operands are available at the location
510/// the op would be moved to).
511struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
512 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
513
514 LogicalResult matchAndRewrite(AllocaScopeOp op,
515 PatternRewriter &rewriter) const override {
516
517 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
518 return failure();
519
520 Operation *lastParentWithoutScope = op->getParentOp();
521
522 if (!lastParentWithoutScope ||
523 lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
524 return failure();
525
526 // Only apply if this is the last non-terminator
527 // op in the block (lest the lifetime be extended) of a
528 // one-block region.
529 if (!lastNonTerminatorInRegion(op) ||
530 !lastNonTerminatorInRegion(lastParentWithoutScope))
531 return failure();
532
533 while (!lastParentWithoutScope->getParentOp()
534 ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
535 lastParentWithoutScope = lastParentWithoutScope->getParentOp();
536 if (!lastParentWithoutScope ||
537 !lastNonTerminatorInRegion(lastParentWithoutScope))
538 return failure();
539 }
540 assert(lastParentWithoutScope->getParentOp()
541 ->hasTrait<OpTrait::AutomaticAllocationScope>());
542
543 Region *containingRegion = nullptr;
544 for (auto &r : lastParentWithoutScope->getRegions()) {
545 if (r.isAncestor(op->getParentRegion())) {
546 assert(containingRegion == nullptr &&
547 "only one region can contain the op");
548 containingRegion = &r;
549 }
550 }
551 assert(containingRegion && "op must be contained in a region");
552
553 SmallVector<Operation *> toHoist;
554 op->walk([&](Operation *alloc) {
555 if (!isGuaranteedAutomaticAllocation(alloc))
556 return WalkResult::skip();
557
558 // If any operand is not defined before the location of
559 // lastParentWithoutScope (i.e. where we would hoist to), skip.
560 if (llvm::any_of(alloc->getOperands(), [&](Value v) {
561 return containingRegion->isAncestor(v.getParentRegion());
562 }))
563 return WalkResult::skip();
564 toHoist.push_back(alloc);
565 return WalkResult::advance();
566 });
567
568 if (toHoist.empty())
569 return failure();
570 rewriter.setInsertionPoint(lastParentWithoutScope);
571 for (auto *op : toHoist) {
572 auto *cloned = rewriter.clone(*op);
573 rewriter.replaceOp(op, cloned->getResults());
574 }
575 return success();
576 }
577};
578
579void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
580 MLIRContext *context) {
581 results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
582}
583
584//===----------------------------------------------------------------------===//
585// AssumeAlignmentOp
586//===----------------------------------------------------------------------===//
587
588LogicalResult AssumeAlignmentOp::verify() {
589 if (!llvm::isPowerOf2_32(getAlignment()))
590 return emitOpError("alignment must be power of 2");
591 return success();
592}
593
594void AssumeAlignmentOp::getAsmResultNames(
595 function_ref<void(Value, StringRef)> setNameFn) {
596 setNameFn(getResult(), "assume_align");
597}
598
599OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
600 auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
601 if (!source)
602 return {};
603 if (source.getAlignment() != getAlignment())
604 return {};
605 return getMemref();
606}
607
608FailureOr<std::optional<SmallVector<Value>>>
609AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
610 return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
611}
612
613FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
614 int resultIndex,
615 int dim) {
616 assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
617 return getMixedSize(builder, getLoc(), getMemref(), dim);
618}
619
620//===----------------------------------------------------------------------===//
621// DistinctObjectsOp
622//===----------------------------------------------------------------------===//
623
624LogicalResult DistinctObjectsOp::verify() {
625 if (getOperandTypes() != getResultTypes())
626 return emitOpError("operand types and result types must match");
627
628 if (getOperandTypes().empty())
629 return emitOpError("expected at least one operand");
630
631 return success();
632}
633
634LogicalResult DistinctObjectsOp::inferReturnTypes(
635 MLIRContext * /*context*/, std::optional<Location> /*location*/,
636 ValueRange operands, DictionaryAttr /*attributes*/,
637 OpaqueProperties /*properties*/, RegionRange /*regions*/,
638 SmallVectorImpl<Type> &inferredReturnTypes) {
639 llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
640 return success();
641}
642
643//===----------------------------------------------------------------------===//
644// CastOp
645//===----------------------------------------------------------------------===//
646
647void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
648 setNameFn(getResult(), "cast");
649}
650
651/// Determines whether MemRef_CastOp casts to a more dynamic version of the
652/// source memref. This is useful to fold a memref.cast into a consuming op
653/// and implement canonicalization patterns for ops in different dialects that
654/// may consume the results of memref.cast operations. Such foldable memref.cast
655/// operations are typically inserted as `view` and `subview` ops are
656/// canonicalized, to preserve the type compatibility of their uses.
657///
658/// Returns true when all conditions are met:
659/// 1. source and result are ranked memrefs with strided semantics and same
660/// element type and rank.
661/// 2. each of the source's size, offset or stride has more static information
662/// than the corresponding result's size, offset or stride.
663///
664/// Example 1:
665/// ```mlir
666/// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
667/// %2 = consumer %1 ... : memref<?x?xf32> ...
668/// ```
669///
670/// may fold into:
671///
672/// ```mlir
673/// %2 = consumer %0 ... : memref<8x16xf32> ...
674/// ```
675///
676/// Example 2:
677/// ```
678/// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
679/// to memref<?x?xf32>
680/// consumer %1 : memref<?x?xf32> ...
681/// ```
682///
683/// may fold into:
684///
685/// ```
686/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
687/// ```
688bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
689 MemRefType sourceType =
690 llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
691 MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
692
693 // Requires ranked MemRefType.
694 if (!sourceType || !resultType)
695 return false;
696
697 // Requires same elemental type.
698 if (sourceType.getElementType() != resultType.getElementType())
699 return false;
700
701 // Requires same rank.
702 if (sourceType.getRank() != resultType.getRank())
703 return false;
704
705 // Only fold casts between strided memref forms.
706 int64_t sourceOffset, resultOffset;
707 SmallVector<int64_t, 4> sourceStrides, resultStrides;
708 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
709 failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
710 return false;
711
712 // If cast is towards more static sizes along any dimension, don't fold.
713 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
714 auto ss = std::get<0>(it), st = std::get<1>(it);
715 if (ss != st)
716 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
717 return false;
718 }
719
720 // If cast is towards more static offset along any dimension, don't fold.
721 if (sourceOffset != resultOffset)
722 if (ShapedType::isDynamic(sourceOffset) &&
723 ShapedType::isStatic(resultOffset))
724 return false;
725
726 // If cast is towards more static strides along any dimension, don't fold.
727 for (auto it : llvm::zip(sourceStrides, resultStrides)) {
728 auto ss = std::get<0>(it), st = std::get<1>(it);
729 if (ss != st)
730 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
731 return false;
732 }
733
734 return true;
735}
736
737bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
738 if (inputs.size() != 1 || outputs.size() != 1)
739 return false;
740 Type a = inputs.front(), b = outputs.front();
741 auto aT = llvm::dyn_cast<MemRefType>(a);
742 auto bT = llvm::dyn_cast<MemRefType>(b);
743
744 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
745 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
746
747 if (aT && bT) {
748 if (aT.getElementType() != bT.getElementType())
749 return false;
750 if (aT.getLayout() != bT.getLayout()) {
751 int64_t aOffset, bOffset;
752 SmallVector<int64_t, 4> aStrides, bStrides;
753 if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
754 failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
755 aStrides.size() != bStrides.size())
756 return false;
757
758 // Strides along a dimension/offset are compatible if the value in the
759 // source memref is static and the value in the target memref is the
760 // same. They are also compatible if either one is dynamic (see
761 // description of MemRefCastOp for details).
762 // Note that for dimensions of size 1, the stride can differ.
763 auto checkCompatible = [](int64_t a, int64_t b) {
764 return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
765 };
766 if (!checkCompatible(aOffset, bOffset))
767 return false;
768 for (const auto &[index, aStride] : enumerate(aStrides)) {
769 if (aT.getDimSize(index) == 1 || bT.getDimSize(index) == 1)
770 continue;
771 if (!checkCompatible(aStride, bStrides[index]))
772 return false;
773 }
774 }
775 if (aT.getMemorySpace() != bT.getMemorySpace())
776 return false;
777
778 // They must have the same rank, and any specified dimensions must match.
779 if (aT.getRank() != bT.getRank())
780 return false;
781
782 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
783 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
784 if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
785 aDim != bDim)
786 return false;
787 }
788 return true;
789 } else {
790 if (!aT && !uaT)
791 return false;
792 if (!bT && !ubT)
793 return false;
794 // Unranked to unranked casting is unsupported
795 if (uaT && ubT)
796 return false;
797
798 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
799 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
800 if (aEltType != bEltType)
801 return false;
802
803 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
804 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
805 return aMemSpace == bMemSpace;
806 }
807
808 return false;
809}
810
811OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
812 return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
813}
814
815FailureOr<std::optional<SmallVector<Value>>>
816CastOp::bubbleDownCasts(OpBuilder &builder) {
817 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
818}
819
820//===----------------------------------------------------------------------===//
821// CopyOp
822//===----------------------------------------------------------------------===//
823
824namespace {
825
826/// Fold memref.copy(%x, %x).
827struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
828 using OpRewritePattern<CopyOp>::OpRewritePattern;
829
830 LogicalResult matchAndRewrite(CopyOp copyOp,
831 PatternRewriter &rewriter) const override {
832 if (copyOp.getSource() != copyOp.getTarget())
833 return failure();
834
835 rewriter.eraseOp(copyOp);
836 return success();
837 }
838};
839
840struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
841 using OpRewritePattern<CopyOp>::OpRewritePattern;
842
843 static bool isEmptyMemRef(BaseMemRefType type) {
844 return type.hasRank() && llvm::is_contained(type.getShape(), 0);
845 }
846
847 LogicalResult matchAndRewrite(CopyOp copyOp,
848 PatternRewriter &rewriter) const override {
849 if (isEmptyMemRef(copyOp.getSource().getType()) ||
850 isEmptyMemRef(copyOp.getTarget().getType())) {
851 rewriter.eraseOp(copyOp);
852 return success();
853 }
854
855 return failure();
856 }
857};
858} // namespace
859
860void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
861 MLIRContext *context) {
862 results.add<FoldEmptyCopy, FoldSelfCopy>(context);
863}
864
865/// If the source/target of a CopyOp is a CastOp that does not modify the shape
866/// and element type, the cast can be skipped. Such CastOps only cast the layout
867/// of the type.
868static LogicalResult foldCopyOfCast(CopyOp op) {
869 for (OpOperand &operand : op->getOpOperands()) {
870 auto castOp = operand.get().getDefiningOp<memref::CastOp>();
871 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
872 operand.set(castOp.getOperand());
873 return success();
874 }
875 }
876 return failure();
877}
878
879LogicalResult CopyOp::fold(FoldAdaptor adaptor,
880 SmallVectorImpl<OpFoldResult> &results) {
881
882 /// copy(memrefcast) -> copy
883 return foldCopyOfCast(*this);
884}
885
886//===----------------------------------------------------------------------===//
887// DeallocOp
888//===----------------------------------------------------------------------===//
889
890LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
891 SmallVectorImpl<OpFoldResult> &results) {
892 /// dealloc(memrefcast) -> dealloc
893 return foldMemRefCast(*this);
894}
895
896//===----------------------------------------------------------------------===//
897// DimOp
898//===----------------------------------------------------------------------===//
899
900void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
901 setNameFn(getResult(), "dim");
902}
903
904void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
905 int64_t index) {
906 auto loc = result.location;
907 Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
908 build(builder, result, source, indexValue);
909}
910
911std::optional<int64_t> DimOp::getConstantIndex() {
912 return getConstantIntValue(getIndex());
913}
914
915Speculation::Speculatability DimOp::getSpeculatability() {
916 auto constantIndex = getConstantIndex();
917 if (!constantIndex)
918 return Speculation::NotSpeculatable;
919
920 auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
921 if (!rankedSourceType)
922 return Speculation::NotSpeculatable;
923
924 if (rankedSourceType.getRank() <= constantIndex)
925 return Speculation::NotSpeculatable;
926
927 return Speculation::Speculatable;
928}
929
930void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
931 SetIntLatticeFn setResultRange) {
932 setResultRange(getResult(),
933 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
934}
935
936/// Return a map from each element in `vals` to its number of occurrences.
937/// Use std::map, since the `vals` here are strides and the
938/// dynamic stride value is the same as the tombstone value for
939/// `DenseMap<int64_t>`.
940static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
941 std::map<int64_t, unsigned> numOccurences;
942 for (auto val : vals)
943 numOccurences[val]++;
944 return numOccurences;
945}
946
947/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
948/// to be a subset of `originalType` with some `1` entries erased, return the
949/// set of indices that specifies which of the entries of `originalShape` are
950/// dropped to obtain `reducedShape`.
951/// This accounts for cases where there are multiple unit-dims, but only a
952/// subset of those are dropped. For MemRefTypes these can be disambiguated
953/// using the strides. If a dimension is dropped the stride must be dropped too.
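/// For illustration (hypothetical types): reducing
/// `memref<1x5x1xf32, strided<[5, 1, 1]>>` to
/// `memref<5x1xf32, strided<[1, 1]>>` with sizes `[1, 5, 1]` marks only dim 0
/// as dropped: its stride 5 no longer appears in the candidate type, whereas
/// stride 1 still occurs twice, so dim 2 is kept.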
954static FailureOr<llvm::SmallBitVector>
955computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
956 ArrayRef<OpFoldResult> sizes) {
957 llvm::SmallBitVector unusedDims(originalType.getRank());
958 if (originalType.getRank() == reducedType.getRank())
959 return unusedDims;
960
961 for (const auto &dim : llvm::enumerate(sizes))
962 if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
963 if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
964 unusedDims.set(dim.index());
965
966 // Early exit for the case where the number of unused dims matches the number
967 // of ranks reduced.
968 if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
969 originalType.getRank())
970 return unusedDims;
971
972 SmallVector<int64_t> originalStrides, candidateStrides;
973 int64_t originalOffset, candidateOffset;
974 if (failed(
975 originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
976 failed(
977 reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
978 return failure();
979
980 // For memrefs, a dimension is truly dropped if its corresponding stride is
981 // also dropped. This is particularly important when more than one of the dims
982 // is 1. Track the number of occurrences of the strides in the original type
983 // and the candidate type. For each dropped dim, that stride should not be
984 // present in the candidate type. Note that there could be multiple dimensions
985 // that have the same size. We don't need to figure out exactly which dim
986 // corresponds to which stride; we just need to verify that the number of
987 // repetitions of a stride in the original type equals its number of
988 // repetitions in the candidate type plus the number of dropped dims with it.
989 std::map<int64_t, unsigned> currUnaccountedStrides =
990 getNumOccurences(originalStrides);
991 std::map<int64_t, unsigned> candidateStridesNumOccurences =
992 getNumOccurences(candidateStrides);
993 for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
994 if (!unusedDims.test(dim))
995 continue;
996 int64_t originalStride = originalStrides[dim];
997 if (currUnaccountedStrides[originalStride] >
998 candidateStridesNumOccurences[originalStride]) {
999 // This dim can be treated as dropped.
1000 currUnaccountedStrides[originalStride]--;
1001 continue;
1002 }
1003 if (currUnaccountedStrides[originalStride] ==
1004 candidateStridesNumOccurences[originalStride]) {
1005 // The stride for this is not dropped. Keep as is.
1006 unusedDims.reset(dim);
1007 continue;
1008 }
1009 if (currUnaccountedStrides[originalStride] <
1010 candidateStridesNumOccurences[originalStride]) {
1011 // This should never happen. Can't have a stride in the reduced rank type
1012 // that wasn't in the original one.
1013 return failure();
1014 }
1015 }
1016
1017 if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1018 originalType.getRank())
1019 return failure();
1020 return unusedDims;
1021}
1022
1023llvm::SmallBitVector SubViewOp::getDroppedDims() {
1024 MemRefType sourceType = getSourceType();
1025 MemRefType resultType = getType();
1026 FailureOr<llvm::SmallBitVector> unusedDims =
1027 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1028 assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1029 return *unusedDims;
1030}
1031
1032OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1033 // All forms of folding require a known index.
1034 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1035 if (!index)
1036 return {};
1037
1038 // Folding for unranked types (UnrankedMemRefType) is not supported.
1039 auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1040 if (!memrefType)
1041 return {};
1042
1043 // Out of bound indices produce undefined behavior but are still valid IR.
1044 // Don't choke on them.
1045 int64_t indexVal = index.getInt();
1046 if (indexVal < 0 || indexVal >= memrefType.getRank())
1047 return {};
1048
1049 // Fold if the shape extent along the given index is known.
1050 if (!memrefType.isDynamicDim(index.getInt())) {
1051 Builder builder(getContext());
1052 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1053 }
1054
1055 // The size at the given index is now known to be a dynamic size.
1056 unsigned unsignedIndex = index.getValue().getZExtValue();
1057
1058 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1059 Operation *definingOp = getSource().getDefiningOp();
1060
1061 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1062 return *(alloc.getDynamicSizes().begin() +
1063 memrefType.getDynamicDimIndex(unsignedIndex));
1064
1065 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1066 return *(alloca.getDynamicSizes().begin() +
1067 memrefType.getDynamicDimIndex(unsignedIndex));
1068
1069 if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1070 return *(view.getDynamicSizes().begin() +
1071 memrefType.getDynamicDimIndex(unsignedIndex));
1072
1073 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1074 llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1075 unsigned resultIndex = 0;
1076 unsigned sourceRank = subview.getSourceType().getRank();
1077 unsigned sourceIndex = 0;
1078 for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1079 if (unusedDims.test(i))
1080 continue;
1081 if (resultIndex == unsignedIndex) {
1082 sourceIndex = i;
1083 break;
1084 }
1085 resultIndex++;
1086 }
1087 assert(subview.isDynamicSize(sourceIndex) &&
1088 "expected dynamic subview size");
1089 return subview.getDynamicSize(sourceIndex);
1090 }
1091
1092 // dim(memrefcast) -> dim
1093 if (succeeded(foldMemRefCast(*this)))
1094 return getResult();
1095
1096 return {};
1097}
1098
1099namespace {
1100/// Fold dim of a memref reshape operation to a load into the reshape's shape
1101/// operand.
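/// For illustration (hypothetical IR, index-typed shape operand):
///
/// ```mlir
/// %r = memref.reshape %src(%shape)
///     : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
/// %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// folds to
///
/// ```mlir
/// %d = memref.load %shape[%c1] : memref<2xindex>
/// ```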
1102struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1103 using OpRewritePattern<DimOp>::OpRewritePattern;
1104
1105 LogicalResult matchAndRewrite(DimOp dim,
1106 PatternRewriter &rewriter) const override {
1107 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1108
1109 if (!reshape)
1110 return rewriter.notifyMatchFailure(
1111 dim, "Dim op is not defined by a reshape op.");
1112
1113 // dim of a memref reshape can be folded if dim.getIndex() dominates the
1114 // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1115 // cheaply check that either of the following conditions hold:
1116 // 1. dim.getIndex() is defined in the same block as reshape but before
1117 // reshape.
1118 // 2. dim.getIndex() is defined in a parent block of
1119 // reshape.
1120
1121 // Check condition 1
1122 if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1123 if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1124 if (reshape->isBeforeInBlock(definingOp)) {
1125 return rewriter.notifyMatchFailure(
1126 dim,
1127 "dim.getIndex is not defined before reshape in the same block.");
1128 }
1129 } // else dim.getIndex is a block argument to reshape->getBlock and
1130 // dominates reshape
1131 } // Check condition 2
1132 else if (dim->getBlock() != reshape->getBlock() &&
1133 !dim.getIndex().getParentRegion()->isProperAncestor(
1134 reshape->getParentRegion())) {
1135 // If dim and reshape are in the same block but dim.getIndex() isn't, we
1136 // already know dim.getIndex() dominates reshape without calling
1137 // `isProperAncestor`
1138 return rewriter.notifyMatchFailure(
1139 dim, "dim.getIndex does not dominate reshape.");
1140 }
1141
1142 // Place the load directly after the reshape to ensure that the shape memref
1143 // was not mutated.
1144 rewriter.setInsertionPointAfter(reshape);
1145 Location loc = dim.getLoc();
1146 Value load =
1147 LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1148 if (load.getType() != dim.getType())
1149 load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
1150 rewriter.replaceOp(dim, load);
1151 return success();
1152 }
1153};
1154
1155} // namespace
1156
1157void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1158 MLIRContext *context) {
1159 results.add<DimOfMemRefReshape>(context);
1160}
1161
1162// ---------------------------------------------------------------------------
1163// DmaStartOp
1164// ---------------------------------------------------------------------------
1165
1166void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1167 Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1168 ValueRange destIndices, Value numElements,
1169 Value tagMemRef, ValueRange tagIndices, Value stride,
1170 Value elementsPerStride) {
1171 result.addOperands(srcMemRef);
1172 result.addOperands(srcIndices);
1173 result.addOperands(destMemRef);
1174 result.addOperands(destIndices);
1175 result.addOperands({numElements, tagMemRef});
1176 result.addOperands(tagIndices);
1177 if (stride)
1178 result.addOperands({stride, elementsPerStride});
1179}
1180
1181void DmaStartOp::print(OpAsmPrinter &p) {
1182 p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1183 << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1184 << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1185 if (isStrided())
1186 p << ", " << getStride() << ", " << getNumElementsPerStride();
1187
1188 p.printOptionalAttrDict((*this)->getAttrs());
1189 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1190 << ", " << getTagMemRef().getType();
1191}
1192
1193// Parse DmaStartOp.
1194// Ex:
1195// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1196// %tag[%index], %stride, %num_elt_per_stride :
1197// : memref<3076 x f32, 0>,
1198// memref<1024 x f32, 2>,
1199// memref<1 x i32>
1200//
1201ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1202 OpAsmParser::UnresolvedOperand srcMemRefInfo;
1203 SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1204 OpAsmParser::UnresolvedOperand dstMemRefInfo;
1205 SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1206 OpAsmParser::UnresolvedOperand numElementsInfo;
1207 OpAsmParser::UnresolvedOperand tagMemrefInfo;
1208 SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1209 SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1210
1211 SmallVector<Type, 3> types;
1212 auto indexType = parser.getBuilder().getIndexType();
1213
1214 // Parse and resolve the following list of operands:
1215 // *) source memref followed by its indices (in square brackets).
1216 // *) destination memref followed by its indices (in square brackets).
1217 // *) dma size in KiB.
1218 if (parser.parseOperand(srcMemRefInfo) ||
1219 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1220 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1221 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1222 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1223 parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1224 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1225 return failure();
1226
1227 // Parse optional stride and elements per stride.
1228 if (parser.parseTrailingOperandList(strideInfo))
1229 return failure();
1230
1231 bool isStrided = strideInfo.size() == 2;
1232 if (!strideInfo.empty() && !isStrided) {
1233 return parser.emitError(parser.getNameLoc(),
1234 "expected two stride related operands");
1235 }
1236
1237 if (parser.parseColonTypeList(types))
1238 return failure();
1239 if (types.size() != 3)
1240 return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1241
1242 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1243 parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1244 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1245 parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1246 // size should be an index.
1247 parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1248 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1249 // tag indices should be index.
1250 parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1251 return failure();
1252
1253 if (isStrided) {
1254 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1255 return failure();
1256 }
1257
1258 return success();
1259}
1260
1261LogicalResult DmaStartOp::verify() {
1262 unsigned numOperands = getNumOperands();
1263
1264 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1265 // the number of elements.
1266 if (numOperands < 4)
1267 return emitOpError("expected at least 4 operands");
1268
1269 // Check types of operands. The order of these calls is important: the later
1270 // calls rely on some type properties to compute the operand position.
1271 // 1. Source memref.
1272 if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1273 return emitOpError("expected source to be of memref type");
1274 if (numOperands < getSrcMemRefRank() + 4)
1275 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1276 << " operands";
1277 if (!getSrcIndices().empty() &&
1278 !llvm::all_of(getSrcIndices().getTypes(),
1279 [](Type t) { return t.isIndex(); }))
1280 return emitOpError("expected source indices to be of index type");
1281
1282 // 2. Destination memref.
1283 if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1284 return emitOpError("expected destination to be of memref type");
1285 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1286 if (numOperands < numExpectedOperands)
1287 return emitOpError() << "expected at least " << numExpectedOperands
1288 << " operands";
1289 if (!getDstIndices().empty() &&
1290 !llvm::all_of(getDstIndices().getTypes(),
1291 [](Type t) { return t.isIndex(); }))
1292 return emitOpError("expected destination indices to be of index type");
1293
1294 // 3. Number of elements.
1295 if (!getNumElements().getType().isIndex())
1296 return emitOpError("expected num elements to be of index type");
1297
1298 // 4. Tag memref.
1299 if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1300 return emitOpError("expected tag to be of memref type");
1301 numExpectedOperands += getTagMemRefRank();
1302 if (numOperands < numExpectedOperands)
1303 return emitOpError() << "expected at least " << numExpectedOperands
1304 << " operands";
1305 if (!getTagIndices().empty() &&
1306 !llvm::all_of(getTagIndices().getTypes(),
1307 [](Type t) { return t.isIndex(); }))
1308 return emitOpError("expected tag indices to be of index type");
1309
1310 // Optional stride-related operands must be either both present or both
1311 // absent.
1312 if (numOperands != numExpectedOperands &&
1313 numOperands != numExpectedOperands + 2)
1314 return emitOpError("incorrect number of operands");
1315
1316 // 5. Strides.
1317 if (isStrided()) {
1318 if (!getStride().getType().isIndex() ||
1319 !getNumElementsPerStride().getType().isIndex())
1320 return emitOpError(
1321 "expected stride and num elements per stride to be of type index");
1322 }
1323
1324 return success();
1325}
1326
1327LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1328 SmallVectorImpl<OpFoldResult> &results) {
1329 /// dma_start(memrefcast) -> dma_start
1330 return foldMemRefCast(*this);
1331}
1332
1333// ---------------------------------------------------------------------------
1334// DmaWaitOp
1335// ---------------------------------------------------------------------------
1336
1337LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1338 SmallVectorImpl<OpFoldResult> &results) {
1339 /// dma_wait(memrefcast) -> dma_wait
1340 return foldMemRefCast(*this);
1341}
1342
1343LogicalResult DmaWaitOp::verify() {
1344 // Check that the number of tag indices matches the tagMemRef rank.
1345 unsigned numTagIndices = getTagIndices().size();
1346 unsigned tagMemRefRank = getTagMemRefRank();
1347 if (numTagIndices != tagMemRefRank)
1348 return emitOpError() << "expected tagIndices to have the same number of "
1349 "elements as the tagMemRef rank, expected "
1350 << tagMemRefRank << ", but got " << numTagIndices;
1351 return success();
1352}
1353
1354//===----------------------------------------------------------------------===//
1355// ExtractAlignedPointerAsIndexOp
1356//===----------------------------------------------------------------------===//
1357
1358void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1359 function_ref<void(Value, StringRef)> setNameFn) {
1360 setNameFn(getResult(), "intptr");
1361}
1362
1363//===----------------------------------------------------------------------===//
1364// ExtractStridedMetadataOp
1365//===----------------------------------------------------------------------===//
1366
1367/// The number and type of the results are inferred from the
1368/// shape of the source.
1369LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1370 MLIRContext *context, std::optional<Location> location,
1371 ExtractStridedMetadataOp::Adaptor adaptor,
1372 SmallVectorImpl<Type> &inferredReturnTypes) {
1373 auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1374 if (!sourceType)
1375 return failure();
1376
1377 unsigned sourceRank = sourceType.getRank();
1378 IndexType indexType = IndexType::get(context);
1379 auto memrefType =
1380 MemRefType::get({}, sourceType.getElementType(),
1381 MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1382 // Base.
1383 inferredReturnTypes.push_back(memrefType);
1384 // Offset.
1385 inferredReturnTypes.push_back(indexType);
1386 // Sizes and strides.
1387 for (unsigned i = 0; i < sourceRank * 2; ++i)
1388 inferredReturnTypes.push_back(indexType);
1389 return success();
1390}
1391
1392void ExtractStridedMetadataOp::getAsmResultNames(
1393 function_ref<void(Value, StringRef)> setNameFn) {
1394 setNameFn(getBaseBuffer(), "base_buffer");
1395 setNameFn(getOffset(), "offset");
1396 // For multi-result to work properly with pretty names and packed syntax `x:3`
1397 // we can only give a pretty name to the first value in the pack.
1398 if (!getSizes().empty()) {
1399 setNameFn(getSizes().front(), "sizes");
1400 setNameFn(getStrides().front(), "strides");
1401 }
1402}
1403
1404/// Helper function to perform the replacement of all constant uses of `values`
1405/// by a materialized constant extracted from `maybeConstants`.
1406/// `values` and `maybeConstants` are expected to have the same size.
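/// For illustration (hypothetical values): if `values` is the `sizes` result
/// range and `maybeConstants` constifies the first size to 16, every use of
/// that first size is rewired to a freshly materialized
/// `arith.constant 16 : index` and the helper returns true.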
1407template <typename Container>
1408static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1409 Container values,
1410 ArrayRef<OpFoldResult> maybeConstants) {
1411 assert(values.size() == maybeConstants.size() &&
1412 " expected values and maybeConstants of the same size");
1413 bool atLeastOneReplacement = false;
1414 for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1415 // Don't materialize a constant if there are no uses: this would induce
1416 // infinite loops in the driver.
1417 if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1418 continue;
1419 assert(isa<Attribute>(maybeConstant) &&
1420 "The constified value should be either unchanged (i.e., == result) "
1421 "or a constant");
1422 Value constantVal = arith::ConstantIndexOp::create(
1423 rewriter, loc,
1424 llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1425 for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1426 // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1427 // yet.
1428 op->replaceUsesOfWith(result, constantVal);
1429 atLeastOneReplacement = true;
1430 }
1431 }
1432 return atLeastOneReplacement;
1433}
1434
1435LogicalResult
1436ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1437 SmallVectorImpl<OpFoldResult> &results) {
1438 OpBuilder builder(*this);
1439
1440 bool atLeastOneReplacement = replaceConstantUsesOf(
1441 builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1442 getConstifiedMixedOffset());
1443 atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1444 getConstifiedMixedSizes());
1445 atLeastOneReplacement |= replaceConstantUsesOf(
1446 builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1447
1448 // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
1449 if (auto prev = getSource().getDefiningOp<CastOp>())
1450 if (isa<MemRefType>(prev.getSource().getType())) {
1451 getSourceMutable().assign(prev.getSource());
1452 atLeastOneReplacement = true;
1453 }
1454
1455 return success(atLeastOneReplacement);
1456}
1457
1458SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1459 SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1460 constifyIndexValues(values, getSource().getType().getShape());
1461 return values;
1462}
1463
1464SmallVector<OpFoldResult>
1465ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1466 SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1467 SmallVector<int64_t> staticValues;
1468 int64_t unused;
1469 LogicalResult status =
1470 getSource().getType().getStridesAndOffset(staticValues, unused);
1471 (void)status;
1472 assert(succeeded(status) && "could not get strides from type");
1473 constifyIndexValues(values, staticValues);
1474 return values;
1475}
1476
1477OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1478 OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1479 SmallVector<OpFoldResult> values(1, offsetOfr);
1480 SmallVector<int64_t> staticValues, unused;
1481 int64_t offset;
1482 LogicalResult status =
1483 getSource().getType().getStridesAndOffset(unused, offset);
1484 (void)status;
1485 assert(succeeded(status) && "could not get offset from type");
1486 staticValues.push_back(offset);
1487 constifyIndexValues(values, staticValues);
1488 return values[0];
1489}
1490
1491//===----------------------------------------------------------------------===//
1492// GenericAtomicRMWOp
1493//===----------------------------------------------------------------------===//
1494
1495void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1496 Value memref, ValueRange ivs) {
1497 OpBuilder::InsertionGuard g(builder);
1498 result.addOperands(memref);
1499 result.addOperands(ivs);
1500
1501 if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1502 Type elementType = memrefType.getElementType();
1503 result.addTypes(elementType);
1504
1505 Region *bodyRegion = result.addRegion();
1506 builder.createBlock(bodyRegion);
1507 bodyRegion->addArgument(elementType, memref.getLoc());
1508 }
1509}
1510
1511LogicalResult GenericAtomicRMWOp::verify() {
1512 auto &body = getRegion();
1513 if (body.getNumArguments() != 1)
1514 return emitOpError("expected single number of entry block arguments");
1515
1516 if (getResult().getType() != body.getArgument(0).getType())
1517 return emitOpError("expected block argument of the same type result type");
1518
1519 bool hasSideEffects =
1520 body.walk([&](Operation *nestedOp) {
1521 if (isMemoryEffectFree(nestedOp))
1522 return WalkResult::advance();
1523 nestedOp->emitError(
1524 "body of 'memref.generic_atomic_rmw' should contain "
1525 "only operations with no side effects");
1526 return WalkResult::interrupt();
1527 })
1528 .wasInterrupted();
1529 return hasSideEffects ? failure() : success();
1530}
1531
1532ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1533 OperationState &result) {
1534 OpAsmParser::UnresolvedOperand memref;
1535 Type memrefType;
1536 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1537
1538 Type indexType = parser.getBuilder().getIndexType();
1539 if (parser.parseOperand(memref) ||
1540 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1541 parser.parseColonType(memrefType) ||
1542 parser.resolveOperand(memref, memrefType, result.operands) ||
1543 parser.resolveOperands(ivs, indexType, result.operands))
1544 return failure();
1545
1546 Region *body = result.addRegion();
1547 if (parser.parseRegion(*body, {}) ||
1548 parser.parseOptionalAttrDict(result.attributes))
1549 return failure();
1550 result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1551 return success();
1552}
1553
1554void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1555 p << ' ' << getMemref() << "[" << getIndices()
1556 << "] : " << getMemref().getType() << ' ';
1557 p.printRegion(getRegion());
1558 p.printOptionalAttrDict((*this)->getAttrs());
1559}
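For reference, an op in the form accepted by this parser and produced by this printer looks roughly as follows (buffer, index, and body are illustrative):
```mlir
%res = memref.generic_atomic_rmw %buf[%i] : memref<10xf32> {
^bb0(%current : f32):
  %cst = arith.constant 1.0 : f32
  %updated = arith.addf %current, %cst : f32
  memref.atomic_yield %updated : f32
}
```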
1560
1561//===----------------------------------------------------------------------===//
1562// AtomicYieldOp
1563//===----------------------------------------------------------------------===//
1564
1565LogicalResult AtomicYieldOp::verify() {
1566 Type parentType = (*this)->getParentOp()->getResultTypes().front();
1567 Type resultType = getResult().getType();
1568 if (parentType != resultType)
1569 return emitOpError() << "types mismatch between yield op: " << resultType
1570 << " and its parent: " << parentType;
1571 return success();
1572}
1573
1574//===----------------------------------------------------------------------===//
1575// GlobalOp
1576//===----------------------------------------------------------------------===//
1577
1578static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1579 TypeAttr type,
1580 Attribute initialValue) {
1581 p << type;
1582 if (!op.isExternal()) {
1583 p << " = ";
1584 if (op.isUninitialized())
1585 p << "uninitialized";
1586 else
1587 p.printAttributeWithoutType(initialValue);
1588 }
1589}
1590
1591static ParseResult
1592parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1593 Attribute &initialValue) {
1594 Type type;
1595 if (parser.parseType(type))
1596 return failure();
1597
1598 auto memrefType = llvm::dyn_cast<MemRefType>(type);
1599 if (!memrefType || !memrefType.hasStaticShape())
1600 return parser.emitError(parser.getNameLoc())
1601 << "type should be static shaped memref, but got " << type;
1602 typeAttr = TypeAttr::get(type);
1603
1604 if (parser.parseOptionalEqual())
1605 return success();
1606
1607 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1608 initialValue = UnitAttr::get(parser.getContext());
1609 return success();
1610 }
1611
1612 Type tensorType = getTensorTypeFromMemRefType(memrefType);
1613 if (parser.parseAttribute(initialValue, tensorType))
1614 return failure();
1615 if (!llvm::isa<ElementsAttr>(initialValue))
1616 return parser.emitError(parser.getNameLoc())
1617 << "initial value should be a unit or elements attribute";
1618 return success();
1619}
1620
1621LogicalResult GlobalOp::verify() {
1622 auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1623 if (!memrefType || !memrefType.hasStaticShape())
1624 return emitOpError("type should be static shaped memref, but got ")
1625 << getType();
1626
1627 // Verify that the initial value, if present, is either a unit attribute or
1628 // an elements attribute.
1629 if (getInitialValue().has_value()) {
1630 Attribute initValue = getInitialValue().value();
1631 if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1632 return emitOpError("initial value should be a unit or elements "
1633 "attribute, but got ")
1634 << initValue;
1635
1636 // Check that the type of the initial value is compatible with the type of
1637 // the global variable.
1638 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1639 // Check the element types match.
1640 auto initElementType =
1641 cast<TensorType>(elementsAttr.getType()).getElementType();
1642 auto memrefElementType = memrefType.getElementType();
1643
1644 if (initElementType != memrefElementType)
1645 return emitOpError("initial value element expected to be of type ")
1646 << memrefElementType << ", but was of type " << initElementType;
1647
1648 // Check the shapes match, given that memref globals can only produce
1649 // statically shaped memrefs and elements literal type must have a static
1650 // shape we can assume both types are shaped.
1651 auto initShape = elementsAttr.getShapedType().getShape();
1652 auto memrefShape = memrefType.getShape();
1653 if (initShape != memrefShape)
1654 return emitOpError("initial value shape expected to be ")
1655 << memrefShape << " but was " << initShape;
1656 }
1657 }
1658
1659 // TODO: verify visibility for declarations.
1660 return success();
1661}
1662
1663ElementsAttr GlobalOp::getConstantInitValue() {
1664 auto initVal = getInitialValue();
1665 if (getConstant() && initVal.has_value())
1666 return llvm::cast<ElementsAttr>(initVal.value());
1667 return {};
1668}
1669
1670//===----------------------------------------------------------------------===//
1671// GetGlobalOp
1672//===----------------------------------------------------------------------===//
1673
1674LogicalResult
1675GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1676 // Verify that the result type is same as the type of the referenced
1677 // memref.global op.
1678 auto global =
1679 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1680 if (!global)
1681 return emitOpError("'")
1682 << getName() << "' does not reference a valid global memref";
1683
1684 Type resultType = getResult().getType();
1685 if (global.getType() != resultType)
1686 return emitOpError("result type ")
1687 << resultType << " does not match type " << global.getType()
1688 << " of the global memref @" << getName();
1689 return success();
1690}
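A minimal sketch of the property being verified (symbol name is hypothetical): the `memref.get_global` result type must be exactly the type of the referenced `memref.global`:
```mlir
memref.global "private" constant @cst : memref<2xf32> = dense<[1.0, 2.0]>
// OK: the result type matches the global's type.
%0 = memref.get_global @cst : memref<2xf32>
```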
1691
1692//===----------------------------------------------------------------------===//
1693// LoadOp
1694//===----------------------------------------------------------------------===//
1695
1696LogicalResult LoadOp::verify() {
1697 if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1698 return emitOpError("incorrect number of indices for load, expected ")
1699 << getMemRefType().getRank() << " but got " << getIndices().size();
1700 }
1701 return success();
1702}
1703
1704OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1705 /// load(memrefcast) -> load
1706 if (succeeded(foldMemRefCast(*this)))
1707 return getResult();
1708
1709 // Fold load from a global constant memref.
1710 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
1711 if (!getGlobalOp)
1712 return {};
1713
1714 // Get to the memref.global defining the symbol.
1715 auto global = SymbolTable::lookupNearestSymbolFrom<memref::GlobalOp>(
1716 getGlobalOp, getGlobalOp.getNameAttr());
1717 if (!global)
1718 return {};
1719 // If it's a splat constant, we can fold irrespective of indices.
1720 auto splatAttr =
1721 dyn_cast_or_null<SplatElementsAttr>(global.getConstantInitValue());
1722 if (!splatAttr)
1723 return {};
1724
1725 return splatAttr.getSplatValue<Attribute>();
1726}
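Sketch of the splat fold above (global name and values are illustrative): a load from a constant global with a splat initializer folds to that scalar for any indices:
```mlir
memref.global "private" constant @zeros : memref<4x4xf32> = dense<0.0>
%m = memref.get_global @zeros : memref<4x4xf32>
// Folds to the attribute 0.0 : f32, independent of %i and %j.
%v = memref.load %m[%i, %j] : memref<4x4xf32>
```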
1727
1728FailureOr<std::optional<SmallVector<Value>>>
1729LoadOp::bubbleDownCasts(OpBuilder &builder) {
1731 getResult());
1732}
1733
1734//===----------------------------------------------------------------------===//
1735// MemorySpaceCastOp
1736//===----------------------------------------------------------------------===//
1737
1738void MemorySpaceCastOp::getAsmResultNames(
1739 function_ref<void(Value, StringRef)> setNameFn) {
1740 setNameFn(getResult(), "memspacecast");
1741}
1742
1743bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1744 if (inputs.size() != 1 || outputs.size() != 1)
1745 return false;
1746 Type a = inputs.front(), b = outputs.front();
1747 auto aT = llvm::dyn_cast<MemRefType>(a);
1748 auto bT = llvm::dyn_cast<MemRefType>(b);
1749
1750 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1751 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1752
1753 if (aT && bT) {
1754 if (aT.getElementType() != bT.getElementType())
1755 return false;
1756 if (aT.getLayout() != bT.getLayout())
1757 return false;
1758 if (aT.getShape() != bT.getShape())
1759 return false;
1760 return true;
1761 }
1762 if (uaT && ubT) {
1763 return uaT.getElementType() == ubT.getElementType();
1764 }
1765 return false;
1766}
1767
1768OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1769 // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1770 // t2)
1771 if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1772 getSourceMutable().assign(parentCast.getSource());
1773 return getResult();
1774 }
1775 return Value{};
1776}
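Illustrative instance of the fold above (memory space numbers are arbitrary): the intermediate cast is bypassed, so only the outermost target space remains:
```mlir
%a = memref.memory_space_cast %src : memref<8xf32, 1> to memref<8xf32, 2>
%b = memref.memory_space_cast %a : memref<8xf32, 2> to memref<8xf32>
// Folds %b to:
%b = memref.memory_space_cast %src : memref<8xf32, 1> to memref<8xf32>
```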
1777
1778TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
1779 return getSource();
1780}
1781
1782TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
1783 return getDest();
1784}
1785
1786bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1787 PtrLikeTypeInterface src) {
1788 return isa<BaseMemRefType>(tgt) &&
1789 tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
1790}
1791
1792MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
1793 OpBuilder &b, PtrLikeTypeInterface tgt,
1795 assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
1796 return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
1797}
1798
1799/// The only cast we recognize as promotable is to the generic space.
1800bool MemorySpaceCastOp::isSourcePromotable() {
1801 return getDest().getType().getMemorySpace() == nullptr;
1802}
1803
1804//===----------------------------------------------------------------------===//
1805// PrefetchOp
1806//===----------------------------------------------------------------------===//
1807
1808void PrefetchOp::print(OpAsmPrinter &p) {
1809 p << " " << getMemref() << '[';
1810 p.printOperands(getIndices());
1811 p << ']' << ", " << (getIsWrite() ? "write" : "read");
1812 p << ", locality<" << getLocalityHint();
1813 p << ">, " << (getIsDataCache() ? "data" : "instr");
1814 p.printOptionalAttrDict(
1815 (*this)->getAttrs(),
1816 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1817 p << " : " << getMemRefType();
1818}
1819
1820ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1821 OpAsmParser::UnresolvedOperand memrefInfo;
1822 SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1823 IntegerAttr localityHint;
1824 MemRefType type;
1825 StringRef readOrWrite, cacheType;
1826
1827 auto indexTy = parser.getBuilder().getIndexType();
1828 auto i32Type = parser.getBuilder().getIntegerType(32);
1829 if (parser.parseOperand(memrefInfo) ||
1830 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1831 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1832 parser.parseComma() || parser.parseKeyword("locality") ||
1833 parser.parseLess() ||
1834 parser.parseAttribute(localityHint, i32Type, "localityHint",
1835 result.attributes) ||
1836 parser.parseGreater() || parser.parseComma() ||
1837 parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1838 parser.resolveOperand(memrefInfo, type, result.operands) ||
1839 parser.resolveOperands(indexInfo, indexTy, result.operands))
1840 return failure();
1841
1842 if (readOrWrite != "read" && readOrWrite != "write")
1843 return parser.emitError(parser.getNameLoc(),
1844 "rw specifier has to be 'read' or 'write'");
1845 result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1846 parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1847
1848 if (cacheType != "data" && cacheType != "instr")
1849 return parser.emitError(parser.getNameLoc(),
1850 "cache type has to be 'data' or 'instr'");
1851
1852 result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1853 parser.getBuilder().getBoolAttr(cacheType == "data"));
1854
1855 return success();
1856}
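The custom syntax handled by this printer/parser pair looks like the following (operands and type are illustrative):
```mlir
memref.prefetch %A[%i, %j], read, locality<3>, data : memref<400x400xf32>
```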
1857
1858LogicalResult PrefetchOp::verify() {
1859 if (getNumOperands() != 1 + getMemRefType().getRank())
1860 return emitOpError("too few indices");
1861
1862 return success();
1863}
1864
1865LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1866 SmallVectorImpl<OpFoldResult> &results) {
1867 // prefetch(memrefcast) -> prefetch
1868 return foldMemRefCast(*this);
1869}
1870
1871//===----------------------------------------------------------------------===//
1872// RankOp
1873//===----------------------------------------------------------------------===//
1874
1875OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1876 // Constant fold rank when the rank of the operand is known.
1877 auto type = getOperand().getType();
1878 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1879 if (shapedType && shapedType.hasRank())
1880 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1881 return IntegerAttr();
1882}
1883
1884//===----------------------------------------------------------------------===//
1885// ReinterpretCastOp
1886//===----------------------------------------------------------------------===//
1887
1888void ReinterpretCastOp::getAsmResultNames(
1889 function_ref<void(Value, StringRef)> setNameFn) {
1890 setNameFn(getResult(), "reinterpret_cast");
1891}
1892
1893/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1894/// `staticSizes` and `staticStrides` are automatically filled with
1895/// source-memref-rank sentinel values that encode dynamic entries.
1896void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1897 MemRefType resultType, Value source,
1898 OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1899 ArrayRef<OpFoldResult> strides,
1900 ArrayRef<NamedAttribute> attrs) {
1901 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1902 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1903 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1904 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1905 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1906 result.addAttributes(attrs);
1907 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1908 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1909 b.getDenseI64ArrayAttr(staticSizes),
1910 b.getDenseI64ArrayAttr(staticStrides));
1911}
1912
1913void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1914 Value source, OpFoldResult offset,
1915 ArrayRef<OpFoldResult> sizes,
1916 ArrayRef<OpFoldResult> strides,
1917 ArrayRef<NamedAttribute> attrs) {
1918 auto sourceType = cast<BaseMemRefType>(source.getType());
1919 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1920 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1921 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1922 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1923 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1924 auto stridedLayout = StridedLayoutAttr::get(
1925 b.getContext(), staticOffsets.front(), staticStrides);
1926 auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1927 stridedLayout, sourceType.getMemorySpace());
1928 build(b, result, resultType, source, offset, sizes, strides, attrs);
1929}
1930
1931void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1932 MemRefType resultType, Value source,
1933 int64_t offset, ArrayRef<int64_t> sizes,
1934 ArrayRef<int64_t> strides,
1935 ArrayRef<NamedAttribute> attrs) {
1936 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
1937 sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
1938 SmallVector<OpFoldResult> strideValues =
1939 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
1940 return b.getI64IntegerAttr(v);
1941 });
1942 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1943 strideValues, attrs);
1944}
1945
1946void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1947 MemRefType resultType, Value source, Value offset,
1948 ValueRange sizes, ValueRange strides,
1949 ArrayRef<NamedAttribute> attrs) {
1950 SmallVector<OpFoldResult> sizeValues =
1951 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
1952 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
1953 strides, [](Value v) -> OpFoldResult { return v; });
1954 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1955}
1956
1957// TODO: ponder whether we want to allow missing trailing sizes/strides that are
1958// completed automatically, like we have for subview and extract_slice.
1959LogicalResult ReinterpretCastOp::verify() {
1960 // The source and result memrefs should be in the same memory space.
1961 auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1962 auto resultType = llvm::cast<MemRefType>(getType());
1963 if (srcType.getMemorySpace() != resultType.getMemorySpace())
1964 return emitError("different memory spaces specified for source type ")
1965 << srcType << " and result memref type " << resultType;
1966 if (failed(verifyElementTypesMatch(*this, srcType, resultType, "source",
1967 "result")))
1968 return failure();
1969
1970 // Match sizes in result memref type and in static_sizes attribute.
1971 for (auto [idx, resultSize, expectedSize] :
1972 llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1973 if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
1974 return emitError("expected result type with size = ")
1975 << (ShapedType::isDynamic(expectedSize)
1976 ? std::string("dynamic")
1977 : std::to_string(expectedSize))
1978 << " instead of " << resultSize << " in dim = " << idx;
1979 }
1980
1981 // Match offset and strides in static_offset and static_strides attributes. If
1982 // result memref type has no affine map specified, this will assume an
1983 // identity layout.
1984 int64_t resultOffset;
1985 SmallVector<int64_t, 4> resultStrides;
1986 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1987 return emitError("expected result type to have strided layout but found ")
1988 << resultType;
1989
1990 // Match offset in result memref type and in static_offsets attribute.
1991 int64_t expectedOffset = getStaticOffsets().front();
1992 if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
1993 return emitError("expected result type with offset = ")
1994 << (ShapedType::isDynamic(expectedOffset)
1995 ? std::string("dynamic")
1996 : std::to_string(expectedOffset))
1997 << " instead of " << resultOffset;
1998
1999 // Match strides in result memref type and in static_strides attribute.
2000 for (auto [idx, resultStride, expectedStride] :
2001 llvm::enumerate(resultStrides, getStaticStrides())) {
2002 if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
2003 return emitError("expected result type with stride = ")
2004 << (ShapedType::isDynamic(expectedStride)
2005 ? std::string("dynamic")
2006 : std::to_string(expectedStride))
2007 << " instead of " << resultStride << " in dim = " << idx;
2008 }
2009
2010 return success();
2011}
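A sketch of an op that passes these checks (shapes are illustrative): the static offset, sizes, and strides agree with the strided layout encoded in the result type:
```mlir
%dst = memref.reinterpret_cast %src to
    offset: [0], sizes: [10, 2], strides: [2, 1]
    : memref<?xf32> to memref<10x2xf32, strided<[2, 1]>>
```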
2012
2013OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
2014 Value src = getSource();
2015 auto getPrevSrc = [&]() -> Value {
2016 // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
2017 if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
2018 return prev.getSource();
2019
2020 // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
2021 if (auto prev = src.getDefiningOp<CastOp>())
2022 return prev.getSource();
2023
2024 // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
2025 // are 0.
2026 if (auto prev = src.getDefiningOp<SubViewOp>())
2027 if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
2028 return prev.getSource();
2029
2030 return nullptr;
2031 };
2032
2033 if (auto prevSrc = getPrevSrc()) {
2034 getSourceMutable().assign(prevSrc);
2035 return getResult();
2036 }
2037
2038 // reinterpret_cast(x) w/o offset/shape/stride changes -> x
2039 if (ShapedType::isStaticShape(getType().getShape()) &&
2040 src.getType() == getType() && getStaticOffsets().front() == 0) {
2041 return src;
2042 }
2043
2044 return nullptr;
2045}
2046
2047SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
2048 SmallVector<OpFoldResult> values = getMixedSizes();
2049 constifyIndexValues(values, getType().getShape());
2050 return values;
2051}
2052
2053SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
2054 SmallVector<OpFoldResult> values = getMixedStrides();
2055 SmallVector<int64_t> staticValues;
2056 int64_t unused;
2057 LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
2058 (void)status;
2059 assert(succeeded(status) && "could not get strides from type");
2060 constifyIndexValues(values, staticValues);
2061 return values;
2062}
2063
2064OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
2065 SmallVector<OpFoldResult> values = getMixedOffsets();
2066 assert(values.size() == 1 &&
2067 "reinterpret_cast must have one and only one offset");
2068 SmallVector<int64_t> staticValues, unused;
2069 int64_t offset;
2070 LogicalResult status = getType().getStridesAndOffset(unused, offset);
2071 (void)status;
2072 assert(succeeded(status) && "could not get offset from type");
2073 staticValues.push_back(offset);
2074 constifyIndexValues(values, staticValues);
2075 return values[0];
2076}
2077
2078namespace {
2079/// Replace the sequence:
2080/// ```
2081/// base, offset, sizes, strides = extract_strided_metadata src
2082/// dst = reinterpret_cast base to offset, sizes, strides
2083/// ```
2084/// With
2085///
2086/// ```
2087/// dst = memref.cast src
2088/// ```
2089///
2090/// Note: The cast operation is only inserted when the type of dst and src
2091/// are not the same. E.g., when going from <4xf32> to <?xf32>.
2092///
2093/// This pattern also matches when the offset, sizes, and strides don't come
2094/// directly from the `extract_strided_metadata`'s results but it can be
2095/// statically proven that they would hold the same values.
2096///
2097/// For instance, the following sequence would be replaced:
2098/// ```
2099/// base, offset, sizes, strides =
2100/// extract_strided_metadata memref : memref<3x4xty>
2101/// dst = reinterpret_cast base to 0, [3, 4], strides
2102/// ```
2103/// Because we know (thanks to the type of the input memref) that variable
2104/// `offset` and `sizes` will respectively hold 0 and [3, 4].
2105///
2106/// Similarly, the following sequence would be replaced:
2107/// ```
2108/// c0 = arith.constant 0
2109/// c4 = arith.constant 4
2110/// base, offset, sizes, strides =
2111/// extract_strided_metadata memref : memref<3x4xty>
2112/// dst = reinterpret_cast base to c0, [3, c4], strides
2113/// ```
2114/// Because we know that `offset` and `c0` will hold 0
2115/// and `c4` will hold 4.
2116///
2117/// If the pattern above does not match, the input of the
2118/// extract_strided_metadata is always folded into the input of the
2119/// reinterpret_cast operator. This allows for dead code elimination to get rid
2120/// of the extract_strided_metadata in some cases.
2121struct ReinterpretCastOpExtractStridedMetadataFolder
2122 : public OpRewritePattern<ReinterpretCastOp> {
2123public:
2124 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2125
2126 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2127 PatternRewriter &rewriter) const override {
2128 auto extractStridedMetadata =
2129 op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2130 if (!extractStridedMetadata)
2131 return failure();
2132
2133 // Check if the reinterpret cast reconstructs a memref with the exact same
2134 // properties as the extract strided metadata.
2135 auto isReinterpretCastNoop = [&]() -> bool {
2136 // First, check that the strides are the same.
2137 if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
2138 op.getConstifiedMixedStrides()))
2139 return false;
2140
2141 // Second, check the sizes.
2142 if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
2143 op.getConstifiedMixedSizes()))
2144 return false;
2145
2146 // Finally, check the offset.
2147 assert(op.getMixedOffsets().size() == 1 &&
2148 "reinterpret_cast with more than one offset should have been "
2149 "rejected by the verifier");
2150 return extractStridedMetadata.getConstifiedMixedOffset() ==
2151 op.getConstifiedMixedOffset();
2152 };
2153
2154 if (!isReinterpretCastNoop()) {
2155 // If the extract_strided_metadata / reinterpret_cast pair can't be
2156 // completely folded, then we could fold the input of the
2157 // extract_strided_metadata into the input of the reinterpret_cast.
2158 // For some cases (e.g., static dimensions) the
2159 // extract_strided_metadata is then eliminated by dead code elimination.
2160 //
2161 // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2162 //
2163 // We can always fold the input of a extract_strided_metadata operator
2164 // to the input of a reinterpret_cast operator, because they point to
2165 // the same memory. Note that the reinterpret_cast does not use the
2166 // layout of its input memref, only its base memory pointer which is
2167 // the same as the base pointer returned by the extract_strided_metadata
2168 // operator and the base pointer of the extract_strided_metadata memref
2169 // input.
2170 rewriter.modifyOpInPlace(op, [&]() {
2171 op.getSourceMutable().assign(extractStridedMetadata.getSource());
2172 });
2173 return success();
2174 }
2175
2176 // At this point, we know that the back and forth between extract strided
2177 // metadata and reinterpret cast is a noop. However, the final type of the
2178 // reinterpret cast may not be exactly the same as the original memref.
2179 // E.g., it could be changing a dimension from static to dynamic. Check that
2180 // here and add a cast if necessary.
2181 Type srcTy = extractStridedMetadata.getSource().getType();
2182 if (srcTy == op.getResult().getType())
2183 rewriter.replaceOp(op, extractStridedMetadata.getSource());
2184 else
2185 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2186 extractStridedMetadata.getSource());
2187
2188 return success();
2189 }
2190};
2191
2192struct ReinterpretCastOpConstantFolder
2193 : public OpRewritePattern<ReinterpretCastOp> {
2194public:
2195 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2196
2197 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2198 PatternRewriter &rewriter) const override {
2199 unsigned srcStaticCount = llvm::count_if(
2200 llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
2201 op.getMixedStrides()),
2202 [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
2203
2204 SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
2205 SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
2206 SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
2207
2208 // TODO: Using counting comparison instead of direct comparison because
2209 // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
2210 // IntegerAttrs, while constifyIndexValues (and therefore
2211 // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
2212 if (srcStaticCount ==
2213 llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
2214 [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
2215 return failure();
2216
2217 auto newReinterpretCast = ReinterpretCastOp::create(
2218 rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
2219
2220 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
2221 return success();
2222 }
2223};
2224} // namespace
2225
2226void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2227 MLIRContext *context) {
2228 results.add<ReinterpretCastOpExtractStridedMetadataFolder,
2229 ReinterpretCastOpConstantFolder>(context);
2230}
2231
2232FailureOr<std::optional<SmallVector<Value>>>
2233ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
2234 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2235}
2236
2237//===----------------------------------------------------------------------===//
2238// Reassociative reshape ops
2239//===----------------------------------------------------------------------===//
2240
2241void CollapseShapeOp::getAsmResultNames(
2242 function_ref<void(Value, StringRef)> setNameFn) {
2243 setNameFn(getResult(), "collapse_shape");
2244}
2245
2246void ExpandShapeOp::getAsmResultNames(
2247 function_ref<void(Value, StringRef)> setNameFn) {
2248 setNameFn(getResult(), "expand_shape");
2249}
2250
2251LogicalResult ExpandShapeOp::reifyResultShapes(
2252 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2253 reifiedResultShapes = {
2254 getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2255 return success();
2256}
2257
2258/// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2259/// result and operand. Layout maps are verified separately.
2260///
2261/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2262/// allowed in a reassociation group.
2263static LogicalResult
2264verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2265 ArrayRef<int64_t> expandedShape,
2266 ArrayRef<ReassociationIndices> reassociation,
2267 bool allowMultipleDynamicDimsPerGroup) {
2268 // There must be one reassociation group per collapsed dimension.
2269 if (collapsedShape.size() != reassociation.size())
2270 return op->emitOpError("invalid number of reassociation groups: found ")
2271 << reassociation.size() << ", expected " << collapsedShape.size();
2272
2273 // The next expected expanded dimension index (while iterating over
2274 // reassociation indices).
2275 int64_t nextDim = 0;
2276 for (const auto &it : llvm::enumerate(reassociation)) {
2277 ReassociationIndices group = it.value();
2278 int64_t collapsedDim = it.index();
2279
2280 bool foundDynamic = false;
2281 for (int64_t expandedDim : group) {
2282 if (expandedDim != nextDim++)
2283 return op->emitOpError("reassociation indices must be contiguous");
2284
2285 if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2286 return op->emitOpError("reassociation index ")
2287 << expandedDim << " is out of bounds";
2288
2289 // Check if there are multiple dynamic dims in a reassociation group.
2290 if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2291 if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2292 return op->emitOpError(
2293 "at most one dimension in a reassociation group may be dynamic");
2294 foundDynamic = true;
2295 }
2296 }
2297
2298 // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2299 if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2300 return op->emitOpError("collapsed dim (")
2301 << collapsedDim
2302 << ") must be dynamic if and only if reassociation group is "
2303 "dynamic";
2304
2305 // If all dims in the reassociation group are static, the size of the
2306 // collapsed dim can be verified.
2307 if (!foundDynamic) {
2308 int64_t groupSize = 1;
2309 for (int64_t expandedDim : group)
2310 groupSize *= expandedShape[expandedDim];
2311 if (groupSize != collapsedShape[collapsedDim])
2312 return op->emitOpError("collapsed dim size (")
2313 << collapsedShape[collapsedDim]
2314 << ") must equal reassociation group size (" << groupSize << ")";
2315 }
2316 }
2317
2318 if (collapsedShape.empty()) {
2319 // Rank 0: All expanded dimensions must be 1.
2320 for (int64_t d : expandedShape)
2321 if (d != 1)
2322 return op->emitOpError(
2323 "rank 0 memrefs can only be extended/collapsed with/from ones");
2324 } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2325 // Rank >= 1: Number of dimensions among all reassociation groups must match
2326 // the result memref rank.
2327 return op->emitOpError("expanded rank (")
2328 << expandedShape.size()
2329 << ") inconsistent with number of reassociation indices (" << nextDim
2330 << ")";
2331 }
2332
2333 return success();
2334}
2335
2336SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2337 return getSymbolLessAffineMaps(getReassociationExprs());
2338}
2339
2340SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2341 return convertReassociationIndicesToExprs(getContext(),
2342 getReassociationIndices());
2343}
2344
2345SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2346 return getSymbolLessAffineMaps(getReassociationExprs());
2347}
2348
2349SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2350 return convertReassociationIndicesToExprs(getContext(),
2351 getReassociationIndices());
2352}
2353
2354/// Compute the layout map after expanding a given source MemRef type with the
2355/// specified reassociation indices.
2356static FailureOr<StridedLayoutAttr>
2357computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2358 ArrayRef<ReassociationIndices> reassociation) {
2359 int64_t srcOffset;
2360 SmallVector<int64_t> srcStrides;
2361 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2362 return failure();
2363 assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2364
2365 // 1-1 mapping between srcStrides and reassociation packs.
2366 // Each srcStride starts with the given value and gets expanded according to
2367 // the proper entries in resultShape.
2368 // Example:
2369 // srcStrides = [10000, 1 , 100 ],
2370 // reassociations = [ [0], [1], [2, 3, 4]],
2371 // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2372 // -> For the purpose of stride calculation, the useful sizes are:
2373 // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2374 // resultStrides = [10000, 1, 600, 200, 100]
2375 // Note that a stride does not get expanded along the first entry of each
2376 // shape pack.
2377 SmallVector<int64_t> reverseResultStrides;
2378 reverseResultStrides.reserve(resultShape.size());
2379 unsigned shapeIndex = resultShape.size() - 1;
2380 for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2381 ReassociationIndices reassoc = std::get<0>(it);
2382 int64_t currentStrideToExpand = std::get<1>(it);
2383 for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2384 reverseResultStrides.push_back(currentStrideToExpand);
2385 currentStrideToExpand =
2386 (SaturatedInteger::wrap(currentStrideToExpand) *
2387 SaturatedInteger::wrap(resultShape[shapeIndex--]))
2388 .asInteger();
2389 }
2390 }
2391 auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2392 resultStrides.resize(resultShape.size(), 1);
2393 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2394}
2395
2396FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2397 MemRefType srcType, ArrayRef<int64_t> resultShape,
2398 ArrayRef<ReassociationIndices> reassociation) {
2399 if (srcType.getLayout().isIdentity()) {
2400 // If the source is contiguous (i.e., no layout map specified), so is the
2401 // result.
2402 MemRefLayoutAttrInterface layout;
2403 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2404 srcType.getMemorySpace());
2405 }
2406
2407 // Source may not be contiguous. Compute the layout map.
2408 FailureOr<StridedLayoutAttr> computedLayout =
2409 computeExpandedLayoutMap(srcType, resultShape, reassociation);
2410 if (failed(computedLayout))
2411 return failure();
2412 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2413 srcType.getMemorySpace());
2414}
2415
2416FailureOr<SmallVector<OpFoldResult>>
2417ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2418 MemRefType expandedType,
2419 ArrayRef<ReassociationIndices> reassociation,
2420 ArrayRef<OpFoldResult> inputShape) {
2421 std::optional<SmallVector<OpFoldResult>> outputShape =
2422 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2423 inputShape);
2424 if (!outputShape)
2425 return failure();
2426 return *outputShape;
2427}
2428
2429void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2430 Type resultType, Value src,
2431 ArrayRef<ReassociationIndices> reassociation,
2432 ArrayRef<OpFoldResult> outputShape) {
2433 auto [staticOutputShape, dynamicOutputShape] =
2434 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2435 build(builder, result, llvm::cast<MemRefType>(resultType), src,
2436 getReassociationIndicesAttribute(builder, reassociation),
2437 dynamicOutputShape, staticOutputShape);
2438}
2439
2440void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2441 Type resultType, Value src,
2442 ArrayRef<ReassociationIndices> reassociation) {
2443 SmallVector<OpFoldResult> inputShape =
2444 getMixedSizes(builder, result.location, src);
2445 MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2446 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2447 builder, result.location, memrefResultTy, reassociation, inputShape);
2448 // Failure of this assertion usually indicates presence of multiple
2449 // dynamic dimensions in the same reassociation group.
2450 assert(succeeded(outputShape) && "unable to infer output shape");
2451 build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2452}
2453
2454void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2455 ArrayRef<int64_t> resultShape, Value src,
2456 ArrayRef<ReassociationIndices> reassociation) {
2457 // Only ranked memref source values are supported.
2458 auto srcType = llvm::cast<MemRefType>(src.getType());
2459 FailureOr<MemRefType> resultType =
2460 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2461 // Failure of this assertion usually indicates a problem with the source
2462 // type, e.g., could not get strides/offset.
2463 assert(succeeded(resultType) && "could not compute layout");
2464 build(builder, result, *resultType, src, reassociation);
2465}
2466
2467void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2468 ArrayRef<int64_t> resultShape, Value src,
2469 ArrayRef<ReassociationIndices> reassociation,
2470 ArrayRef<OpFoldResult> outputShape) {
2471 // Only ranked memref source values are supported.
2472 auto srcType = llvm::cast<MemRefType>(src.getType());
2473 FailureOr<MemRefType> resultType =
2474 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2475 // Failure of this assertion usually indicates a problem with the source
2476 // type, e.g., could not get strides/offset.
2477 assert(succeeded(resultType) && "could not compute layout");
2478 build(builder, result, *resultType, src, reassociation, outputShape);
2479}
2480
2481LogicalResult ExpandShapeOp::verify() {
2482 MemRefType srcType = getSrcType();
2483 MemRefType resultType = getResultType();
2484
2485 if (srcType.getRank() > resultType.getRank()) {
2486 auto r0 = srcType.getRank();
2487 auto r1 = resultType.getRank();
2488 return emitOpError("has source rank ")
2489 << r0 << " and result rank " << r1 << ". This is not an expansion ("
2490 << r0 << " > " << r1 << ").";
2491 }
2492
2493 // Verify result shape.
2494 if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2495 resultType.getShape(),
2496 getReassociationIndices(),
2497 /*allowMultipleDynamicDimsPerGroup=*/true)))
2498 return failure();
2499
2500 // Compute expected result type (including layout map).
2501 FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2502 srcType, resultType.getShape(), getReassociationIndices());
2503 if (failed(expectedResultType))
2504 return emitOpError("invalid source layout map");
2505
2506 // Check actual result type.
2507 if (*expectedResultType != resultType)
2508 return emitOpError("expected expanded type to be ")
2509 << *expectedResultType << " but found " << resultType;
2510
2511 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2512 return emitOpError("expected number of static shape bounds to be equal to "
2513 "the output rank (")
2514 << resultType.getRank() << ") but found "
2515 << getStaticOutputShape().size() << " inputs instead";
2516
2517 if ((int64_t)getOutputShape().size() !=
2518 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2519 return emitOpError("mismatch in dynamic dims in output_shape and "
2520 "static_output_shape: static_output_shape has ")
2521 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2522 << " dynamic dims while output_shape has " << getOutputShape().size()
2523 << " values";
2524
2525 // Verify if provided output shapes are in agreement with output type.
2526 DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2527 ArrayRef<int64_t> resShape = getResult().getType().getShape();
2528 for (auto [pos, shape] : llvm::enumerate(resShape)) {
2529 if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
2530 return emitOpError("invalid output shape provided at pos ") << pos;
2531 }
2532 }
2533
2534 return success();
2535}
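For reference, an expansion that satisfies these checks, including the `output_shape` bookkeeping for the group that contains the dynamic dimension (values are illustrative):
```mlir
// Group [0, 1] expands the dynamic source dim; %sz supplies its runtime extent.
%r = memref.expand_shape %m [[0, 1], [2]] output_shape [%sz, 4, 32]
    : memref<?x32xf32> into memref<?x4x32xf32>
```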
2536
2537struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
2538public:
2539 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2540
2541 LogicalResult matchAndRewrite(ExpandShapeOp op,
2542 PatternRewriter &rewriter) const override {
2543 auto cast = op.getSrc().getDefiningOp<CastOp>();
2544 if (!cast)
2545 return failure();
2546
2547 if (!CastOp::canFoldIntoConsumerOp(cast))
2548 return failure();
2549
2550 SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
2551 SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
2552 SmallVector<int64_t> newOutputShapeSizes;
2553
2554 // Convert output shape dims from dynamic to static where possible.
2555 for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
2556 std::optional<int64_t> sizeOpt = getConstantIntValue(dimSize);
2557 if (!sizeOpt.has_value()) {
2558 newOutputShapeSizes.push_back(ShapedType::kDynamic);
2559 continue;
2560 }
2561
2562 newOutputShapeSizes.push_back(sizeOpt.value());
2563 newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
2564 }
2565
2566 Value castSource = cast.getSource();
2567 auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
2568 SmallVector<ReassociationIndices> reassociationIndices =
2569 op.getReassociationIndices();
2570 for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
2571 auto newOutputShapeSizesSlice =
2572 ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
2573 bool newOutputDynamic =
2574 llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
2575 if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
2576 return rewriter.notifyMatchFailure(
2577 op, "folding cast will result in changing dynamicity in "
2578 "reassociation group");
2579 }
2580
2581 FailureOr<MemRefType> newResultTypeOrFailure =
2582 ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
2583 reassociationIndices);
2584
2585 if (failed(newResultTypeOrFailure))
2586 return rewriter.notifyMatchFailure(
2587 op, "could not compute new expanded type after folding cast");
2588
2589 if (*newResultTypeOrFailure == op.getResultType()) {
2590 rewriter.modifyOpInPlace(
2591 op, [&]() { op.getSrcMutable().assign(castSource); });
2592 } else {
2593 Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
2594 *newResultTypeOrFailure, castSource,
2595 reassociationIndices, newOutputShape);
2596 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2597 }
2598 return success();
2599 }
2600};
2601
2602void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2603 MLIRContext *context) {
2604 results.add<
2605 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2606 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2607 ExpandShapeOpMemRefCastFolder>(context);
2608}
2609
2610FailureOr<std::optional<SmallVector<Value>>>
2611ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
2612 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2613}
2614
2615/// Compute the layout map after collapsing a given source MemRef type with the
2616/// specified reassociation indices.
2617///
2618/// Note: All collapsed dims in a reassociation group must be contiguous. It is
2619/// not possible to check this by inspecting a MemRefType in the general case.
2620/// If non-contiguity cannot be checked statically, the collapse is assumed to
2621/// be valid (and thus accepted by this function) unless `strict = true`.
2622static FailureOr<StridedLayoutAttr>
2623computeCollapsedLayoutMap(MemRefType srcType,
2624 ArrayRef<ReassociationIndices> reassociation,
2625 bool strict = false) {
2626 int64_t srcOffset;
2627 SmallVector<int64_t> srcStrides;
2628 auto srcShape = srcType.getShape();
2629 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2630 return failure();
2631
2632 // The result stride of a reassociation group is the stride of the last entry
2633 // of the reassociation. (TODO: Should be the minimum stride in the
2634 // reassociation because strides are not necessarily sorted. E.g., when using
2635 // memref.transpose.) Dimensions of size 1 should be skipped, because their
2636 // strides are meaningless and could have any arbitrary value.
2637 SmallVector<int64_t> resultStrides;
2638 resultStrides.reserve(reassociation.size());
2639 for (const ReassociationIndices &reassoc : reassociation) {
2640 ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2641 while (srcShape[ref.back()] == 1 && ref.size() > 1)
2642 ref = ref.drop_back();
2643 if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2644 resultStrides.push_back(srcStrides[ref.back()]);
2645 } else {
2646 // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2647 // the corresponding stride may have to be skipped. (See above comment.)
2648 // Therefore, the result stride cannot be statically determined and must
2649 // be dynamic.
2650 resultStrides.push_back(ShapedType::kDynamic);
2651 }
2652 }
2653
2654 // Validate that each reassociation group is contiguous.
2655 unsigned resultStrideIndex = resultStrides.size() - 1;
2656 for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2657 auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2658 auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2659 for (int64_t idx : llvm::reverse(trailingReassocs)) {
2660 stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2661
2662 // Dimensions of size 1 should be skipped, because their strides are
2663 // meaningless and could have any arbitrary value.
2664 if (srcShape[idx - 1] == 1)
2665 continue;
2666
2667 // Both source and result stride must have the same static value. In that
2668 // case, we can be sure, that the dimensions are collapsible (because they
2669 // are contiguous).
2670 // If `strict = false` (default during op verification), we accept cases
2671 // where one or both strides are dynamic. This is best effort: We reject
2672 // ops where obviously non-contiguous dims are collapsed, but accept ops
2673 // where we cannot be sure statically. Such ops may fail at runtime. See
2674 // the op documentation for details.
2675 auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2676 if (strict && (stride.saturated || srcStride.saturated))
2677 return failure();
2678
2679 if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2680 return failure();
2681 }
2682 }
2683 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2684}
2685
2686bool CollapseShapeOp::isGuaranteedCollapsible(
2687 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2688 // MemRefs with identity layout are always collapsible.
2689 if (srcType.getLayout().isIdentity())
2690 return true;
2691
2692 return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2693 /*strict=*/true));
2694}
2695
2696MemRefType CollapseShapeOp::computeCollapsedType(
2697 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2698 SmallVector<int64_t> resultShape;
2699 resultShape.reserve(reassociation.size());
2700 for (const ReassociationIndices &group : reassociation) {
2701 auto groupSize = SaturatedInteger::wrap(1);
2702 for (int64_t srcDim : group)
2703 groupSize =
2704 groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2705 resultShape.push_back(groupSize.asInteger());
2706 }
2707
2708 if (srcType.getLayout().isIdentity()) {
2709 // If the source is contiguous (i.e., no layout map specified), so is the
2710 // result.
2711 MemRefLayoutAttrInterface layout;
2712 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2713 srcType.getMemorySpace());
2714 }
2715
2716 // Source may not be fully contiguous. Compute the layout map.
2717 // Note: Dimensions that are collapsed into a single dim are assumed to be
2718 // contiguous.
2719 FailureOr<StridedLayoutAttr> computedLayout =
2720 computeCollapsedLayoutMap(srcType, reassociation);
2721 assert(succeeded(computedLayout) &&
2722 "invalid source layout map or collapsing non-contiguous dims");
2723 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2724 srcType.getMemorySpace());
2725}
2726
2727void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2728 ArrayRef<ReassociationIndices> reassociation,
2729 ArrayRef<NamedAttribute> attrs) {
2730 auto srcType = llvm::cast<MemRefType>(src.getType());
2731 MemRefType resultType =
2732 CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2733 result.addAttribute(::mlir::getReassociationAttrName(),
2734 getReassociationIndicesAttribute(b, reassociation));
2735 build(b, result, resultType, src, attrs);
2736}
2737
2738LogicalResult CollapseShapeOp::verify() {
2739 MemRefType srcType = getSrcType();
2740 MemRefType resultType = getResultType();
2741
2742 if (srcType.getRank() < resultType.getRank()) {
2743 auto r0 = srcType.getRank();
2744 auto r1 = resultType.getRank();
2745 return emitOpError("has source rank ")
2746 << r0 << " and result rank " << r1 << ". This is not a collapse ("
2747 << r0 << " < " << r1 << ").";
2748 }
2749
2750 // Verify result shape.
2751 if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2752 srcType.getShape(), getReassociationIndices(),
2753 /*allowMultipleDynamicDimsPerGroup=*/true)))
2754 return failure();
2755
2756 // Compute expected result type (including layout map).
2757 MemRefType expectedResultType;
2758 if (srcType.getLayout().isIdentity()) {
2759 // If the source is contiguous (i.e., no layout map specified), so is the
2760 // result.
2761 MemRefLayoutAttrInterface layout;
2762 expectedResultType =
2763 MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2764 srcType.getMemorySpace());
2765 } else {
2766 // Source may not be fully contiguous. Compute the layout map.
2767 // Note: Dimensions that are collapsed into a single dim are assumed to be
2768 // contiguous.
2769 FailureOr<StridedLayoutAttr> computedLayout =
2770 computeCollapsedLayoutMap(srcType, getReassociationIndices());
2771 if (failed(computedLayout))
2772 return emitOpError(
2773 "invalid source layout map or collapsing non-contiguous dims");
2774 expectedResultType =
2775 MemRefType::get(resultType.getShape(), srcType.getElementType(),
2776 *computedLayout, srcType.getMemorySpace());
2777 }
2778
2779 if (expectedResultType != resultType)
2780 return emitOpError("expected collapsed type to be ")
2781 << expectedResultType << " but found " << resultType;
2782
2783 return success();
2784}
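A sketch of a collapse accepted by this verifier (layout is illustrative): the two inner dimensions are contiguous in the strided source, so they can be merged, and the result layout is the one computed above:
```mlir
%r = memref.collapse_shape %m [[0], [1, 2]]
    : memref<2x3x4xf32, strided<[24, 4, 1]>> into memref<2x12xf32, strided<[24, 1]>>
```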
2785
2786class CollapseShapeOpMemRefCastFolder final
2787 : public OpRewritePattern<CollapseShapeOp> {
2788public:
2789 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2790
2791 LogicalResult matchAndRewrite(CollapseShapeOp op,
2792 PatternRewriter &rewriter) const override {
2793 auto cast = op.getOperand().getDefiningOp<CastOp>();
2794 if (!cast)
2795 return failure();
2796
2797 if (!CastOp::canFoldIntoConsumerOp(cast))
2798 return failure();
2799
2800 Type newResultType = CollapseShapeOp::computeCollapsedType(
2801 llvm::cast<MemRefType>(cast.getOperand().getType()),
2802 op.getReassociationIndices());
2803
2804 if (newResultType == op.getResultType()) {
2805 rewriter.modifyOpInPlace(
2806 op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2807 } else {
2808 Value newOp =
2809 CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
2810 op.getReassociationIndices());
2811 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2812 }
2813 return success();
2814 }
2815};
2816
2817void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2818 MLIRContext *context) {
2819 results.add<
2820 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2821 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2822 memref::DimOp, MemRefType>,
2823 CollapseShapeOpMemRefCastFolder>(context);
2824}
2825
2826OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2827 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2828 adaptor.getOperands());
2829}
2830
2831OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2832 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2833 adaptor.getOperands());
2834}
2835
2836FailureOr<std::optional<SmallVector<Value>>>
2837CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
2838 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2839}
2840
2841//===----------------------------------------------------------------------===//
2842// ReshapeOp
2843//===----------------------------------------------------------------------===//
2844
2845void ReshapeOp::getAsmResultNames(
2846 function_ref<void(Value, StringRef)> setNameFn) {
2847 setNameFn(getResult(), "reshape");
2848}
2849
2850LogicalResult ReshapeOp::verify() {
2851 Type operandType = getSource().getType();
2852 Type resultType = getResult().getType();
2853
2854 Type operandElementType =
2855 llvm::cast<ShapedType>(operandType).getElementType();
2856 Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2857 if (operandElementType != resultElementType)
2858 return emitOpError("element types of source and destination memref "
2859 "types should be the same");
2860
2861 if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2862 if (!operandMemRefType.getLayout().isIdentity())
2863 return emitOpError("source memref type should have identity affine map");
2864
2865 int64_t shapeSize =
2866 llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2867 auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2868 if (resultMemRefType) {
2869 if (!resultMemRefType.getLayout().isIdentity())
2870 return emitOpError("result memref type should have identity affine map");
2871 if (shapeSize == ShapedType::kDynamic)
2872 return emitOpError("cannot use shape operand with dynamic length to "
2873 "reshape to statically-ranked memref type");
2874 if (shapeSize != resultMemRefType.getRank())
2875 return emitOpError(
2876 "length of shape operand differs from the result's memref rank");
2877 }
2878 return success();
2879}
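A sketch of the accepted form (names are illustrative): the shape operand is a rank-1 memref of `index`, and its static length must equal the result rank when the result is ranked:
```mlir
%dst = memref.reshape %src(%shape)
    : (memref<4x5xf32>, memref<1xindex>) -> memref<20xf32>
```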
2880
2881FailureOr<std::optional<SmallVector<Value>>>
2882ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
2883 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2884}
2885
2886//===----------------------------------------------------------------------===//
2887// StoreOp
2888//===----------------------------------------------------------------------===//
2889
2890LogicalResult StoreOp::verify() {
2891 if (getNumOperands() != 2 + getMemRefType().getRank())
2892 return emitOpError("store index operand count not equal to memref rank");
2893
2894 return success();
2895}
2896
2897LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2898 SmallVectorImpl<OpFoldResult> &results) {
2899 /// store(memrefcast) -> store
2900 return foldMemRefCast(*this, getValueToStore());
2901}
2902
2903FailureOr<std::optional<SmallVector<Value>>>
2904StoreOp::bubbleDownCasts(OpBuilder &builder) {
2906 ValueRange());
2907}
2908
2909//===----------------------------------------------------------------------===//
2910// SubViewOp
2911//===----------------------------------------------------------------------===//
2912
2913void SubViewOp::getAsmResultNames(
2914 function_ref<void(Value, StringRef)> setNameFn) {
2915 setNameFn(getResult(), "subview");
2916}
2917
2918/// A subview result type can be fully inferred from the source type and the
2919/// static representation of offsets, sizes and strides. Special sentinels
2920/// encode the dynamic case.
2921MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2922 ArrayRef<int64_t> staticOffsets,
2923 ArrayRef<int64_t> staticSizes,
2924 ArrayRef<int64_t> staticStrides) {
2925 unsigned rank = sourceMemRefType.getRank();
2926 (void)rank;
2927 assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2928 assert(staticSizes.size() == rank && "staticSizes length mismatch");
2929 assert(staticStrides.size() == rank && "staticStrides length mismatch");
2930
2931 // Extract source offset and strides.
2932 auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2933
2934 // Compute target offset whose value is:
2935 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2936 int64_t targetOffset = sourceOffset;
2937 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2938 auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2939 targetOffset = (SaturatedInteger::wrap(targetOffset) +
2940 SaturatedInteger::wrap(staticOffset) *
2941 SaturatedInteger::wrap(sourceStride))
2942 .asInteger();
2943 }
2944
2945 // Compute target stride whose value is:
2946 // `sourceStrides_i * staticStrides_i`.
2947 SmallVector<int64_t, 4> targetStrides;
2948 targetStrides.reserve(staticOffsets.size());
2949 for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2950 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2951 targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2952 SaturatedInteger::wrap(staticStride))
2953 .asInteger());
2954 }
2955
2956 // The type is now known.
2957 return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2958 StridedLayoutAttr::get(sourceMemRefType.getContext(),
2959 targetOffset, targetStrides),
2960 sourceMemRefType.getMemorySpace());
2961}
2962
2963MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2964 ArrayRef<OpFoldResult> offsets,
2965 ArrayRef<OpFoldResult> sizes,
2966 ArrayRef<OpFoldResult> strides) {
2967 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2968 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2969 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2970 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2971 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2972 if (!hasValidSizesOffsets(staticOffsets))
2973 return {};
2974 if (!hasValidSizesOffsets(staticSizes))
2975 return {};
2976 if (!hasValidStrides(staticStrides))
2977 return {};
2978 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2979 staticSizes, staticStrides);
2980}
2981
2982MemRefType SubViewOp::inferRankReducedResultType(
2983 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2984 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2985 ArrayRef<int64_t> strides) {
2986 MemRefType inferredType =
2987 inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2988   assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2989          "expected result shape rank to not exceed the inferred rank");
2990 if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2991 return inferredType;
2992
2993 // Compute which dimensions are dropped.
2994 std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2995 computeRankReductionMask(inferredType.getShape(), resultShape);
2996 assert(dimsToProject.has_value() && "invalid rank reduction");
2997
2998 // Compute the layout and result type.
2999 auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
3000 SmallVector<int64_t> rankReducedStrides;
3001 rankReducedStrides.reserve(resultShape.size());
3002 for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
3003 if (!dimsToProject->contains(idx))
3004 rankReducedStrides.push_back(value);
3005 }
3006 return MemRefType::get(resultShape, inferredType.getElementType(),
3007 StridedLayoutAttr::get(inferredLayout.getContext(),
3008 inferredLayout.getOffset(),
3009 rankReducedStrides),
3010 inferredType.getMemorySpace());
3011}
3012
3013MemRefType SubViewOp::inferRankReducedResultType(
3014 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
3015 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
3016 ArrayRef<OpFoldResult> strides) {
3017 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3018 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3019 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3020 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3021 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3022 return SubViewOp::inferRankReducedResultType(
3023 resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
3024 staticStrides);
3025}
3026
3027// Build a SubViewOp with mixed static and dynamic entries and custom result
3028// type. If the type passed is nullptr, it is inferred.
3029void SubViewOp::build(OpBuilder &b, OperationState &result,
3030 MemRefType resultType, Value source,
3031 ArrayRef<OpFoldResult> offsets,
3032 ArrayRef<OpFoldResult> sizes,
3033 ArrayRef<OpFoldResult> strides,
3034 ArrayRef<NamedAttribute> attrs) {
3035 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3036 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3037 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3038 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3039 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3040 auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
3041 // Structuring implementation this way avoids duplication between builders.
3042 if (!resultType) {
3043 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
3044 staticSizes, staticStrides);
3045 }
3046 result.addAttributes(attrs);
3047 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
3048 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3049 b.getDenseI64ArrayAttr(staticSizes),
3050 b.getDenseI64ArrayAttr(staticStrides));
3051}
3052
3053// Build a SubViewOp with mixed static and dynamic entries and inferred result
3054// type.
3055void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3056 ArrayRef<OpFoldResult> offsets,
3057 ArrayRef<OpFoldResult> sizes,
3058 ArrayRef<OpFoldResult> strides,
3059 ArrayRef<NamedAttribute> attrs) {
3060 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
3061}
3062
3063// Build a SubViewOp with static entries and inferred result type.
3064void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3065 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3066 ArrayRef<int64_t> strides,
3067 ArrayRef<NamedAttribute> attrs) {
3068 SmallVector<OpFoldResult> offsetValues =
3069 llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
3070 return b.getI64IntegerAttr(v);
3071 });
3072 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
3073 sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
3074 SmallVector<OpFoldResult> strideValues =
3075 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
3076 return b.getI64IntegerAttr(v);
3077 });
3078 build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
3079}
3080
3081 // Build a SubViewOp with static entries and custom result type. If the
3082 // type passed is nullptr, it is inferred.
3083void SubViewOp::build(OpBuilder &b, OperationState &result,
3084 MemRefType resultType, Value source,
3085 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3086 ArrayRef<int64_t> strides,
3087 ArrayRef<NamedAttribute> attrs) {
3088 SmallVector<OpFoldResult> offsetValues =
3089 llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
3090 return b.getI64IntegerAttr(v);
3091 });
3092 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
3093 sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
3094 SmallVector<OpFoldResult> strideValues =
3095 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
3096 return b.getI64IntegerAttr(v);
3097 });
3098 build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
3099 attrs);
3100}
3101
3102// Build a SubViewOp with dynamic entries and custom result type. If the type
3103// passed is nullptr, it is inferred.
3104void SubViewOp::build(OpBuilder &b, OperationState &result,
3105 MemRefType resultType, Value source, ValueRange offsets,
3106 ValueRange sizes, ValueRange strides,
3107 ArrayRef<NamedAttribute> attrs) {
3108 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3109 offsets, [](Value v) -> OpFoldResult { return v; });
3110 SmallVector<OpFoldResult> sizeValues =
3111 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
3112 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3113 strides, [](Value v) -> OpFoldResult { return v; });
3114 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
3115}
3116
3117// Build a SubViewOp with dynamic entries and inferred result type.
3118void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3119 ValueRange offsets, ValueRange sizes, ValueRange strides,
3120 ArrayRef<NamedAttribute> attrs) {
3121 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
3122}
3123
3124/// For ViewLikeOpInterface.
3125Value SubViewOp::getViewSource() { return getSource(); }
3126
3127/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
3128/// static value).
3129static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
3130 int64_t t1Offset, t2Offset;
3131 SmallVector<int64_t> t1Strides, t2Strides;
3132 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3133 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3134 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
3135}
3136
3137/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
3138/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
3139/// marked as dropped in `droppedDims`.
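/// For example (illustrative): t1 = memref<2x1x4xf32, strided<[8, 4, 1]>> and
/// t2 = memref<2x4xf32, strided<[8, 1]>> have compatible strides when
/// `droppedDims` marks dimension 1 (the unit dimension of t1) as dropped.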
3140static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
3141 const llvm::SmallBitVector &droppedDims) {
3142 assert(size_t(t1.getRank()) == droppedDims.size() &&
3143 "incorrect number of bits");
3144 assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
3145 "incorrect number of dropped dims");
3146 int64_t t1Offset, t2Offset;
3147 SmallVector<int64_t> t1Strides, t2Strides;
3148 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3149 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3150 if (failed(res1) || failed(res2))
3151 return false;
3152 for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
3153 if (droppedDims[i])
3154 continue;
3155 if (t1Strides[i] != t2Strides[j])
3156 return false;
3157 ++j;
3158 }
3159 return true;
3160}
3161 
3162 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
3163                                             SubViewOp op, Type expectedType) {
3164 auto memrefType = llvm::cast<ShapedType>(expectedType);
3165   switch (result) {
3166   case SliceVerificationResult::Success:
3167     return success();
3168   case SliceVerificationResult::RankTooLarge:
3169     return op->emitError("expected result rank to be smaller or equal to ")
3170            << "the source rank, but got " << op.getType();
3171   case SliceVerificationResult::SizeMismatch:
3172     return op->emitError("expected result type to be ")
3173            << expectedType
3174            << " or a rank-reduced version. (mismatch of result sizes), but got "
3175            << op.getType();
3176   case SliceVerificationResult::ElemTypeMismatch:
3177     return op->emitError("expected result element type to be ")
3178            << memrefType.getElementType() << ", but got " << op.getType();
3179   case SliceVerificationResult::MemSpaceMismatch:
3180     return op->emitError(
3181                "expected result and source memory spaces to match, but got ")
3182            << op.getType();
3183   case SliceVerificationResult::LayoutMismatch:
3184     return op->emitError("expected result type to be ")
3185            << expectedType
3186            << " or a rank-reduced version. (mismatch of result layout), but "
3187               "got "
3188            << op.getType();
3189 }
3190 llvm_unreachable("unexpected subview verification result");
3191}
3192
3193/// Verifier for SubViewOp.
3194LogicalResult SubViewOp::verify() {
3195 MemRefType baseType = getSourceType();
3196 MemRefType subViewType = getType();
3197 ArrayRef<int64_t> staticOffsets = getStaticOffsets();
3198 ArrayRef<int64_t> staticSizes = getStaticSizes();
3199 ArrayRef<int64_t> staticStrides = getStaticStrides();
3200
3201 // The base memref and the view memref should be in the same memory space.
3202 if (baseType.getMemorySpace() != subViewType.getMemorySpace())
3203 return emitError("different memory spaces specified for base memref "
3204 "type ")
3205 << baseType << " and subview memref type " << subViewType;
3206
3207 // Verify that the base memref type has a strided layout map.
3208 if (!baseType.isStrided())
3209 return emitError("base type ") << baseType << " is not strided";
3210
3211 // Compute the expected result type, assuming that there are no rank
3212 // reductions.
3213 MemRefType expectedType = SubViewOp::inferResultType(
3214 baseType, staticOffsets, staticSizes, staticStrides);
3215
3216 // Verify all properties of a shaped type: rank, element type and dimension
3217 // sizes. This takes into account potential rank reductions.
3218 auto shapedTypeVerification = isRankReducedType(
3219 /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
3220 if (shapedTypeVerification != SliceVerificationResult::Success)
3221 return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
3222
3223 // Make sure that the memory space did not change.
3224   if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3225     return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
3226                                   *this, expectedType);
3227
3228 // Verify the offset of the layout map.
3229   if (!haveCompatibleOffsets(expectedType, subViewType))
3230     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3231                                   *this, expectedType);
3232
3233   // All that is left to verify now are the strides. First, compute
3234 // the unused dimensions due to rank reductions. We have to look at sizes and
3235 // strides to decide which dimensions were dropped. This function also
3236 // partially verifies strides in case of rank reductions.
3237 auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3238 getMixedSizes());
3239   if (failed(unusedDims))
3240     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3241                                   *this, expectedType);
3242
3243 // Strides must match.
3244   if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3245     return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3246                                   *this, expectedType);
3247
3248 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3249 // to the base memref.
3250 SliceBoundsVerificationResult boundsResult =
3251 verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
3252 staticStrides, /*generateErrorMessage=*/true);
3253 if (!boundsResult.isValid)
3254 return getOperation()->emitError(boundsResult.errorMessage);
3255
3256 return success();
3257}
3258
3259 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3260   return os << "range " << range.offset << ":" << range.size << ":"
3261 << range.stride;
3262}
3263
3264/// Return the list of Range (i.e. offset, size, stride). Each Range
3265/// entry contains either the dynamic value or a ConstantIndexOp constructed
3266/// with `b` at location `loc`.
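/// For example (illustrative, names hypothetical): for
/// `memref.subview %src[%o, 0][4, %n][1, 1]`, this returns the ranges
/// {%o, %c4, %c1} and {%c0, %n, %c1}, where %c0, %c1 and %c4 are
/// `arith.constant` index ops created with `b` at `loc`.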
3267SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3268 OpBuilder &b, Location loc) {
3269 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3270 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3271 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3272   SmallVector<Range, 8> res;
3273   unsigned rank = ranks[0];
3274 res.reserve(rank);
3275 for (unsigned idx = 0; idx < rank; ++idx) {
3276 Value offset =
3277 op.isDynamicOffset(idx)
3278 ? op.getDynamicOffset(idx)
3279 : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
3280 Value size =
3281 op.isDynamicSize(idx)
3282 ? op.getDynamicSize(idx)
3283 : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
3284 Value stride =
3285 op.isDynamicStride(idx)
3286 ? op.getDynamicStride(idx)
3287 : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
3288 res.emplace_back(Range{offset, size, stride});
3289 }
3290 return res;
3291}
3292
3293/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3294/// to deduce the result type for the given `sourceType`. Additionally, reduce
3295/// the rank of the inferred result type if `currentResultType` is lower rank
3296/// than `currentSourceType`. Use this signature if `sourceType` is updated
3297/// together with the result type. In this case, it is important to compute
3298/// the dropped dimensions using `currentSourceType` whose strides align with
3299/// `currentResultType`.
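/// As an illustrative sketch: with currentSourceType = memref<?x?xf32>,
/// currentResultType = memref<4xf32, strided<[?], offset: ?>> (dimension 0
/// dropped), sourceType = memref<8x16xf32> and mixed offsets/sizes/strides
/// [0, 0]/[1, 4]/[1, 1], the non-rank-reduced inference yields
/// memref<1x4xf32, strided<[16, 1]>>, and dropping dimension 0 gives the
/// canonical result memref<4xf32, strided<[1]>>.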
3300 static MemRefType getCanonicalSubViewResultType(
3301     MemRefType currentResultType, MemRefType currentSourceType,
3302 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3303 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3304 MemRefType nonRankReducedType = SubViewOp::inferResultType(
3305 sourceType, mixedOffsets, mixedSizes, mixedStrides);
3306 FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3307 currentSourceType, currentResultType, mixedSizes);
3308 if (failed(unusedDims))
3309 return nullptr;
3310
3311 auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3312 SmallVector<int64_t> shape, strides;
3313 unsigned numDimsAfterReduction =
3314 nonRankReducedType.getRank() - unusedDims->count();
3315 shape.reserve(numDimsAfterReduction);
3316 strides.reserve(numDimsAfterReduction);
3317 for (const auto &[idx, size, stride] :
3318 llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3319 nonRankReducedType.getShape(), layout.getStrides())) {
3320 if (unusedDims->test(idx))
3321 continue;
3322 shape.push_back(size);
3323 strides.push_back(stride);
3324 }
3325
3326 return MemRefType::get(shape, nonRankReducedType.getElementType(),
3327 StridedLayoutAttr::get(sourceType.getContext(),
3328 layout.getOffset(), strides),
3329 nonRankReducedType.getMemorySpace());
3330}
3331
3332 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3333     OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3334 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3335 unsigned rank = memrefType.getRank();
3336 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3337   SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3338   SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3339 MemRefType targetType = SubViewOp::inferRankReducedResultType(
3340 targetShape, memrefType, offsets, sizes, strides);
3341 return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3342 sizes, strides);
3343}
3344
3345FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3346 Value value,
3347 ArrayRef<int64_t> desiredShape) {
3348 auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3349 assert(sourceMemrefType && "not a ranked memref type");
3350 auto sourceShape = sourceMemrefType.getShape();
3351 if (sourceShape.equals(desiredShape))
3352 return value;
3353 auto maybeRankReductionMask =
3354 mlir::computeRankReductionMask(sourceShape, desiredShape);
3355 if (!maybeRankReductionMask)
3356 return failure();
3357 return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3358}
3359
3360/// Helper method to check if a `subview` operation is trivially a no-op. This
3361 /// is the case if all offsets are zero, all strides are 1, and the source
3362 /// shape is the same as the size of the subview. In such cases, the subview can
3363/// be folded into its source.
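/// For example (illustrative), the following subview is trivial:
///   memref.subview %src[0, 0][8, 16][1, 1] : memref<8x16xf32> to ...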
3364static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3365 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3366 return false;
3367
3368 auto mixedOffsets = subViewOp.getMixedOffsets();
3369 auto mixedSizes = subViewOp.getMixedSizes();
3370 auto mixedStrides = subViewOp.getMixedStrides();
3371
3372 // Check offsets are zero.
3373 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3374 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3375 return !intValue || intValue.value() != 0;
3376 }))
3377 return false;
3378
3379 // Check strides are one.
3380 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3381 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3382 return !intValue || intValue.value() != 1;
3383 }))
3384 return false;
3385
3386   // Check that all size values are static and match the (static) source shape.
3387 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3388 for (const auto &size : llvm::enumerate(mixedSizes)) {
3389 std::optional<int64_t> intValue = getConstantIntValue(size.value());
3390 if (!intValue || *intValue != sourceShape[size.index()])
3391 return false;
3392 }
3393 // All conditions met. The `SubViewOp` is foldable as a no-op.
3394 return true;
3395}
3396
3397namespace {
3398/// Pattern to rewrite a subview op with MemRefCast arguments.
3399/// This essentially pushes memref.cast past its consuming subview when
3400/// `canFoldIntoConsumerOp` is true.
3401///
3402/// Example:
3403/// ```
3404/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3405/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3406/// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3407/// ```
3408/// is rewritten into:
3409/// ```
3410 /// %0 = memref.subview %V[0, 0][3, 4][1, 1] : memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3411/// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3412/// memref<3x4xf32, strided<[?, 1], offset: ?>>
3413/// ```
3414class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3415public:
3416 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3417
3418 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3419 PatternRewriter &rewriter) const override {
3420 // Any constant operand, just return to let SubViewOpConstantFolder kick
3421 // in.
3422 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3423 return matchPattern(operand, matchConstantIndex());
3424 }))
3425 return failure();
3426
3427 auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3428 if (!castOp)
3429 return failure();
3430
3431 if (!CastOp::canFoldIntoConsumerOp(castOp))
3432 return failure();
3433
3434 // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3435 // the MemRefCastOp source operand type to infer the result type and the
3436 // current SubViewOp source operand type to compute the dropped dimensions
3437 // if the operation is rank-reducing.
3438 auto resultType = getCanonicalSubViewResultType(
3439 subViewOp.getType(), subViewOp.getSourceType(),
3440 llvm::cast<MemRefType>(castOp.getSource().getType()),
3441 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3442 subViewOp.getMixedStrides());
3443 if (!resultType)
3444 return failure();
3445
3446 Value newSubView = SubViewOp::create(
3447 rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
3448 subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3449 subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3450 subViewOp.getStaticStrides());
3451 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3452 newSubView);
3453 return success();
3454 }
3455};
3456
3457 /// Canonicalize subview ops that are no-ops. When the source and result types
3458 /// differ (e.g. only in their layout due to `affine_map`), a cast is inserted.
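/// For example (illustrative):
///   %0 = memref.subview %src[0, 0][8, 16][1, 1]
///          : memref<8x16xf32> to memref<8x16xf32, strided<[16, 1]>>
/// is rewritten into:
///   %0 = memref.cast %src
///          : memref<8x16xf32> to memref<8x16xf32, strided<[16, 1]>>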
3459class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3460public:
3461 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3462
3463 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3464 PatternRewriter &rewriter) const override {
3465 if (!isTrivialSubViewOp(subViewOp))
3466 return failure();
3467 if (subViewOp.getSourceType() == subViewOp.getType()) {
3468 rewriter.replaceOp(subViewOp, subViewOp.getSource());
3469 return success();
3470 }
3471 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3472 subViewOp.getSource());
3473 return success();
3474 }
3475};
3476} // namespace
3477
3478/// Return the canonical type of the result of a subview.
3479 struct SubViewReturnTypeCanonicalizer {
3480   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3481 ArrayRef<OpFoldResult> mixedSizes,
3482 ArrayRef<OpFoldResult> mixedStrides) {
3483 // Infer a memref type without taking into account any rank reductions.
3484 MemRefType resTy = SubViewOp::inferResultType(
3485 op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3486 if (!resTy)
3487 return {};
3488 MemRefType nonReducedType = resTy;
3489
3490 // Directly return the non-rank reduced type if there are no dropped dims.
3491 llvm::SmallBitVector droppedDims = op.getDroppedDims();
3492 if (droppedDims.none())
3493 return nonReducedType;
3494
3495 // Take the strides and offset from the non-rank reduced type.
3496 auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3497
3498 // Drop dims from shape and strides.
3499 SmallVector<int64_t> targetShape;
3500 SmallVector<int64_t> targetStrides;
3501 for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3502 if (droppedDims.test(i))
3503 continue;
3504 targetStrides.push_back(nonReducedStrides[i]);
3505 targetShape.push_back(nonReducedType.getDimSize(i));
3506 }
3507
3508 return MemRefType::get(targetShape, nonReducedType.getElementType(),
3509 StridedLayoutAttr::get(nonReducedType.getContext(),
3510 offset, targetStrides),
3511 nonReducedType.getMemorySpace());
3512 }
3513};
3514
3515/// A canonicalizer wrapper to replace SubViewOps.
3516 struct SubViewCanonicalizer {
3517   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3518 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3519 }
3520};
3521
3522void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3523 MLIRContext *context) {
3524 results
3525 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3526 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3527 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3528}
3529
3530OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3531 MemRefType sourceMemrefType = getSource().getType();
3532 MemRefType resultMemrefType = getResult().getType();
3533 auto resultLayout =
3534 dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3535
3536 if (resultMemrefType == sourceMemrefType &&
3537 resultMemrefType.hasStaticShape() &&
3538 (!resultLayout || resultLayout.hasStaticLayout())) {
3539 return getViewSource();
3540 }
3541
3542 // Fold subview(subview(x)), where both subviews have the same size and the
3543 // second subview's offsets are all zero. (I.e., the second subview is a
3544 // no-op.)
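  // For example (illustrative, names hypothetical):
  //   %0 = memref.subview %src[%o][4][1]
  //          : memref<?xf32> to memref<4xf32, strided<[1], offset: ?>>
  //   %1 = memref.subview %0[0][4][1]
  //          : memref<4xf32, strided<[1], offset: ?>>
  //            to memref<4xf32, strided<[1], offset: ?>>
  // Here %1 folds to %0.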
3545 if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3546 auto srcSizes = srcSubview.getMixedSizes();
3547 auto sizes = getMixedSizes();
3548 auto offsets = getMixedOffsets();
3549 bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3550 auto strides = getMixedStrides();
3551 bool allStridesOne = llvm::all_of(strides, isOneInteger);
3552 bool allSizesSame = llvm::equal(sizes, srcSizes);
3553 if (allOffsetsZero && allStridesOne && allSizesSame &&
3554 resultMemrefType == sourceMemrefType)
3555 return getViewSource();
3556 }
3557
3558 return {};
3559}
3560
3561FailureOr<std::optional<SmallVector<Value>>>
3562SubViewOp::bubbleDownCasts(OpBuilder &builder) {
3563 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3564}
3565
3566void SubViewOp::inferStridedMetadataRanges(
3567 ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
3568 SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
3569 auto isUninitialized =
3570 +[](IntegerValueRange range) { return range.isUninitialized(); };
3571
3572 // Bail early if any of the operands metadata is not ready:
3573 SmallVector<IntegerValueRange> offsetOperands =
3574 getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
3575 if (llvm::any_of(offsetOperands, isUninitialized))
3576 return;
3577
3578 SmallVector<IntegerValueRange> sizeOperands =
3579 getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
3580 if (llvm::any_of(sizeOperands, isUninitialized))
3581 return;
3582
3583 SmallVector<IntegerValueRange> stridesOperands =
3584 getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
3585 if (llvm::any_of(stridesOperands, isUninitialized))
3586 return;
3587
3588 StridedMetadataRange sourceRange =
3589 ranges[getSourceMutable().getOperandNumber()];
3590 if (sourceRange.isUninitialized())
3591 return;
3592
3593 ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
3594
3595 // Get the dropped dims.
3596 llvm::SmallBitVector droppedDims = getDroppedDims();
3597
3598 // Compute the new offset, strides and sizes.
3599 ConstantIntRanges offset = sourceRange.getOffsets()[0];
3600 SmallVector<ConstantIntRanges> strides, sizes;
3601
3602 for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
3603 bool dropped = droppedDims.test(i);
3604 // Compute the new offset.
3605 ConstantIntRanges off =
3606 intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
3607 offset = intrange::inferAdd({offset, off});
3608
3609 // Skip dropped dimensions.
3610 if (dropped)
3611 continue;
3612 // Multiply the strides.
3613 strides.push_back(
3614 intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
3615 // Get the sizes.
3616 sizes.push_back(sizeOperands[i].getValue());
3617 }
3618
3619 setMetadata(getResult(),
3620               StridedMetadataRange::getRanked(
3621                   SmallVector<ConstantIntRanges>({std::move(offset)}),
3622 std::move(sizes), std::move(strides)));
3623}
3624
3625//===----------------------------------------------------------------------===//
3626// TransposeOp
3627//===----------------------------------------------------------------------===//
3628
3629void TransposeOp::getAsmResultNames(
3630 function_ref<void(Value, StringRef)> setNameFn) {
3631 setNameFn(getResult(), "transpose");
3632}
3633
3634/// Build a strided memref type by applying `permutationMap` to `memRefType`.
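/// For example (illustrative): applying (d0, d1) -> (d1, d0) to
/// memref<4x8xf32> (strides [8, 1], offset 0) yields
/// memref<8x4xf32, strided<[1, 8]>>.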
3635static MemRefType inferTransposeResultType(MemRefType memRefType,
3636 AffineMap permutationMap) {
3637 auto originalSizes = memRefType.getShape();
3638 auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3639 assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3640
3641 // Compute permuted sizes and strides.
3642 auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3643 auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3644
3645 return MemRefType::Builder(memRefType)
3646 .setShape(sizes)
3647 .setLayout(
3648 StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3649}
3650
3651void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3652 AffineMapAttr permutation,
3653 ArrayRef<NamedAttribute> attrs) {
3654 auto permutationMap = permutation.getValue();
3655 assert(permutationMap);
3656
3657 auto memRefType = llvm::cast<MemRefType>(in.getType());
3658 // Compute result type.
3659 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3660
3661 result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3662 build(b, result, resultType, in, attrs);
3663}
3664
3665// transpose $in $permutation attr-dict : type($in) `to` type(results)
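// For example (illustrative), the printed form looks like:
//   memref.transpose %0 (d0, d1) -> (d1, d0)
//     : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>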
3666void TransposeOp::print(OpAsmPrinter &p) {
3667 p << " " << getIn() << " " << getPermutation();
3668 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3669 p << " : " << getIn().getType() << " to " << getType();
3670}
3671
3672ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3673 OpAsmParser::UnresolvedOperand in;
3674 AffineMap permutation;
3675 MemRefType srcType, dstType;
3676 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3677 parser.parseOptionalAttrDict(result.attributes) ||
3678 parser.parseColonType(srcType) ||
3679 parser.resolveOperand(in, srcType, result.operands) ||
3680 parser.parseKeywordType("to", dstType) ||
3681 parser.addTypeToList(dstType, result.types))
3682 return failure();
3683
3684 result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3685 AffineMapAttr::get(permutation));
3686 return success();
3687}
3688
3689LogicalResult TransposeOp::verify() {
3690 if (!getPermutation().isPermutation())
3691 return emitOpError("expected a permutation map");
3692 if (getPermutation().getNumDims() != getIn().getType().getRank())
3693 return emitOpError("expected a permutation map of same rank as the input");
3694
3695 auto srcType = llvm::cast<MemRefType>(getIn().getType());
3696 auto resultType = llvm::cast<MemRefType>(getType());
3697 auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3698 .canonicalizeStridedLayout();
3699
3700 if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3701 return emitOpError("result type ")
3702 << resultType
3703 << " is not equivalent to the canonical transposed input type "
3704 << canonicalResultType;
3705 return success();
3706}
3707
3708OpFoldResult TransposeOp::fold(FoldAdaptor) {
3709   // First check for an identity permutation; we can fold it away if the input
3710   // and result types are already identical.
3711 if (getPermutation().isIdentity() && getType() == getIn().getType())
3712 return getIn();
3713 // Fold two consecutive memref.transpose Ops into one by composing their
3714 // permutation maps.
3715 if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3716 AffineMap composedPermutation =
3717 getPermutation().compose(otherTransposeOp.getPermutation());
3718 getInMutable().assign(otherTransposeOp.getIn());
3719 setPermutation(composedPermutation);
3720 return getResult();
3721 }
3722 return {};
3723}
3724
3725FailureOr<std::optional<SmallVector<Value>>>
3726TransposeOp::bubbleDownCasts(OpBuilder &builder) {
3727 return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
3728}
3729
3730//===----------------------------------------------------------------------===//
3731// ViewOp
3732//===----------------------------------------------------------------------===//
3733
3734void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3735 setNameFn(getResult(), "view");
3736}
3737
3738LogicalResult ViewOp::verify() {
3739 auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3740 auto viewType = getType();
3741
3742 // The base memref should have identity layout map (or none).
3743 if (!baseType.getLayout().isIdentity())
3744 return emitError("unsupported map for base memref type ") << baseType;
3745
3746 // The result memref should have identity layout map (or none).
3747 if (!viewType.getLayout().isIdentity())
3748 return emitError("unsupported map for result memref type ") << viewType;
3749
3750 // The base memref and the view memref should be in the same memory space.
3751 if (baseType.getMemorySpace() != viewType.getMemorySpace())
3752 return emitError("different memory spaces specified for base memref "
3753 "type ")
3754 << baseType << " and view memref type " << viewType;
3755
3756 // Verify that we have the correct number of sizes for the result type.
3757 if (failed(verifyDynamicDimensionCount(getOperation(), viewType, getSizes())))
3758 return failure();
3759
3760 return success();
3761}
3762
3763Value ViewOp::getViewSource() { return getSource(); }
3764
3765OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3766 MemRefType sourceMemrefType = getSource().getType();
3767 MemRefType resultMemrefType = getResult().getType();
3768
3769 if (resultMemrefType == sourceMemrefType &&
3770 resultMemrefType.hasStaticShape() && isZeroInteger(getByteShift()))
3771 return getViewSource();
3772
3773 return {};
3774}
3775
3776SmallVector<OpFoldResult> ViewOp::getMixedSizes() {
3777 SmallVector<OpFoldResult> result;
3778 unsigned ctr = 0;
3779 Builder b(getContext());
3780 for (int64_t dim : getType().getShape()) {
3781 if (ShapedType::isDynamic(dim)) {
3782 result.push_back(getSizes()[ctr++]);
3783 } else {
3784 result.push_back(b.getIndexAttr(dim));
3785 }
3786 }
3787 return result;
3788}
3789
3790namespace {
3791/// Given a memref type and a range of values that defines its dynamic
3792/// dimension sizes, turn all dynamic sizes that have a constant value into
3793/// static dimension sizes.
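/// For example (illustrative, names hypothetical): for type memref<?x?xf32>
/// with dynamicSizes (%c4, %n), where %c4 is a constant 4, this returns
/// memref<4x?xf32> and appends %n to `foldedDynamicSizes`.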
3794static MemRefType
3795foldDynamicToStaticDimSizes(MemRefType type, ValueRange dynamicSizes,
3796 SmallVectorImpl<Value> &foldedDynamicSizes) {
3797 SmallVector<int64_t> staticShape(type.getShape());
3798 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
3799 "incorrect number of dynamic sizes");
3800
3801 // Compute new static and dynamic sizes.
3802 unsigned ctr = 0;
3803 for (auto [dim, dimSize] : llvm::enumerate(type.getShape())) {
3804 if (ShapedType::isStatic(dimSize))
3805 continue;
3806
3807 Value dynamicSize = dynamicSizes[ctr++];
3808 if (auto cst = getConstantIntValue(dynamicSize)) {
3809 // Dynamic size must be non-negative.
3810 if (cst.value() < 0) {
3811 foldedDynamicSizes.push_back(dynamicSize);
3812 continue;
3813 }
3814 staticShape[dim] = cst.value();
3815 } else {
3816 foldedDynamicSizes.push_back(dynamicSize);
3817 }
3818 }
3819
3820 return MemRefType::Builder(type).setShape(staticShape);
3821}
3822
3823/// Change the result type of a `memref.view` by making originally dynamic
3824/// dimensions static when their sizes come from `constant` ops.
3825/// Example:
3826/// ```
3827/// %c5 = arith.constant 5: index
3828/// %0 = memref.view %src[%offset][%c5] : memref<?xi8> to memref<?x4xf32>
3829/// ```
3830/// to
3831/// ```
3832/// %0 = memref.view %src[%offset][] : memref<?xi8> to memref<5x4xf32>
3833/// ```
3834struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3835 using Base::Base;
3836
3837 LogicalResult matchAndRewrite(ViewOp viewOp,
3838 PatternRewriter &rewriter) const override {
3839 SmallVector<Value> foldedDynamicSizes;
3840 MemRefType resultType = viewOp.getType();
3841 MemRefType foldedMemRefType = foldDynamicToStaticDimSizes(
3842 resultType, viewOp.getSizes(), foldedDynamicSizes);
3843
3844 // Stop here if no dynamic size was promoted to static.
3845 if (foldedMemRefType == resultType)
3846 return failure();
3847
3848 // Create new ViewOp.
3849 auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), foldedMemRefType,
3850 viewOp.getSource(), viewOp.getByteShift(),
3851 foldedDynamicSizes);
3852 // Insert a cast so we have the same type as the old memref type.
3853 rewriter.replaceOpWithNewOp<CastOp>(viewOp, resultType, newViewOp);
3854 return success();
3855 }
3856};
3857
3858/// view(memref.cast(%source)) -> view(%source).
3859struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3860 using Base::Base;
3861
3862 LogicalResult matchAndRewrite(ViewOp viewOp,
3863 PatternRewriter &rewriter) const override {
3864 auto memrefCastOp = viewOp.getSource().getDefiningOp<CastOp>();
3865 if (!memrefCastOp)
3866 return failure();
3867
3868 rewriter.replaceOpWithNewOp<ViewOp>(
3869 viewOp, viewOp.getType(), memrefCastOp.getSource(),
3870 viewOp.getByteShift(), viewOp.getSizes());
3871 return success();
3872 }
3873};
3874} // namespace
3875
3876void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3877 MLIRContext *context) {
3878 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3879}
3880
3881FailureOr<std::optional<SmallVector<Value>>>
3882ViewOp::bubbleDownCasts(OpBuilder &builder) {
3883 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3884}
3885
3886//===----------------------------------------------------------------------===//
3887// AtomicRMWOp
3888//===----------------------------------------------------------------------===//
3889
3890LogicalResult AtomicRMWOp::verify() {
3891 if (getMemRefType().getRank() != getNumOperands() - 2)
3892 return emitOpError(
3893 "expects the number of subscripts to be equal to memref rank");
3894 switch (getKind()) {
3895 case arith::AtomicRMWKind::addf:
3896 case arith::AtomicRMWKind::maximumf:
3897 case arith::AtomicRMWKind::minimumf:
3898 case arith::AtomicRMWKind::mulf:
3899 if (!llvm::isa<FloatType>(getValue().getType()))
3900 return emitOpError() << "with kind '"
3901 << arith::stringifyAtomicRMWKind(getKind())
3902 << "' expects a floating-point type";
3903 break;
3904 case arith::AtomicRMWKind::addi:
3905 case arith::AtomicRMWKind::maxs:
3906 case arith::AtomicRMWKind::maxu:
3907 case arith::AtomicRMWKind::mins:
3908 case arith::AtomicRMWKind::minu:
3909 case arith::AtomicRMWKind::muli:
3910 case arith::AtomicRMWKind::ori:
3911 case arith::AtomicRMWKind::xori:
3912 case arith::AtomicRMWKind::andi:
3913 if (!llvm::isa<IntegerType>(getValue().getType()))
3914 return emitOpError() << "with kind '"
3915 << arith::stringifyAtomicRMWKind(getKind())
3916 << "' expects an integer type";
3917 break;
3918 default:
3919 break;
3920 }
3921 return success();
3922}
3923
3924OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3925 /// atomicrmw(memrefcast) -> atomicrmw
3926 if (succeeded(foldMemRefCast(*this, getValue())))
3927 return getResult();
3928 return OpFoldResult();
3929}
3930
3931FailureOr<std::optional<SmallVector<Value>>>
3932AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
3933   return bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
3934                                               getResult());
3935}
3936
3937//===----------------------------------------------------------------------===//
3938// TableGen'd op method definitions
3939//===----------------------------------------------------------------------===//
3940
3941#define GET_OP_CLASSES
3942#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool hasSideEffects(Operation *op)
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition IRAffine.cpp:69
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
auto load
static LogicalResult foldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
Definition MemRefOps.cpp:98
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:255
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
IndexType getIndexType()
Definition Builders.cpp:55
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Region.h:98
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StridedMetadataRange getRanked(SmallVectorImpl< ConstantIntRanges > &&offsets, SmallVectorImpl< ConstantIntRanges > &&sizes, SmallVectorImpl< ConstantIntRanges > &&strides)
Returns a ranked strided metadata range.
ArrayRef< ConstantIntRanges > getStrides() const
Get the strides ranges.
bool isUninitialized() const
Returns whether the metadata is uninitialized.
ArrayRef< ConstantIntRanges > getOffsets() const
Get the offsets range.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIndex() const
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition MemRefOps.cpp:62
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the size of dimension dim of the given memref value.
Definition MemRefOps.cpp:70
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
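A minimal usage sketch, not taken from this file: a fold hook on a hypothetical memref-consuming op (MyMemRefConsumerOp and its single result are assumed names for illustration) absorbs producing memref.cast ops via foldMemRefCast.
```cpp
// Hypothetical op: fold "myop(memref.cast(x))" into "myop(x)" by letting
// foldMemRefCast rewrite the operands of this op in place.
OpFoldResult MyMemRefConsumerOp::fold(FoldAdaptor /*adaptor*/) {
  // Succeeds iff at least one ranked memref.cast source was folded into an
  // operand; returning the op's own result signals an in-place update.
  if (succeeded(memref::foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}
```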
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:79
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes to reduce the rank of memref to that of targetShape.
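A hedged sketch of the helper above, assuming the source value has type memref&lt;1x8xf32&gt;; dropLeadingUnitDim is a hypothetical name chosen for illustration.
```cpp
// Hypothetical helper: drop the leading unit dimension of a memref<1x8xf32>
// by creating the canonical rank-reducing subview @[0, 0], sizes [1, 8],
// strides [1, 1], yielding a result of shape {8}.
static Value dropLeadingUnitDim(OpBuilder &b, Location loc, Value src) {
  return memref::createCanonicalRankReducingSubViewOp(b, loc, src,
                                                      /*targetShape=*/{8});
}
```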
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:578
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:68
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
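A small sketch of how getMixedValues recombines the split static/dynamic storage used by offset/size/stride-style ops; rebuildMixedOffsets is a hypothetical helper name.
```cpp
// Hypothetical helper: entries of staticOffsets equal to ShapedType::kDynamic
// are replaced, in order, by the next Value in dynamicOffsets; all other
// entries become index attributes.
static SmallVector<OpFoldResult>
rebuildMixedOffsets(ArrayRef<int64_t> staticOffsets, ValueRange dynamicOffsets,
                    MLIRContext *ctx) {
  return getMixedValues(staticOffsets, dynamicOffsets, ctx);
}
```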
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
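A sketch combining the matchPattern entry point with the m_ConstantInt binder listed above; getConstantIndexValue is a hypothetical helper name.
```cpp
// Hypothetical helper: return the constant integer behind `v`, if any, by
// binding the matched APInt through m_ConstantInt.
static std::optional<int64_t> getConstantIndexValue(Value v) {
  APInt bound;
  if (matchPattern(v, m_ConstantInt(&bound)))
    return bound.getSExtValue();
  return std::nullopt;
}
```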
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
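A worked sketch of the bounds check: an 8x8 slice at offsets (2, 2) with unit strides fits inside a 16x16 shape because 2 + (8 - 1) * 1 = 9 &lt; 16 in each dimension; sliceFits is a hypothetical name.
```cpp
// Hypothetical check: verify a fully static slice against a 16x16 shape.
static bool sliceFits() {
  SliceBoundsVerificationResult result = verifyInBoundsSlice(
      /*shape=*/{16, 16}, /*staticOffsets=*/{2, 2},
      /*staticSizes=*/{8, 8}, /*staticStrides=*/{1, 1});
  // isValid is true here; with generateErrorMessage=true an out-of-bounds
  // slice would also populate result.errorMessage.
  return result.isValid;
}
```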
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size and stride) entries for the op.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
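A sketch of the splitting direction, the inverse of the getMixedValues recombination shown earlier; splitMixedOffsets is a hypothetical helper name.
```cpp
// Hypothetical helper: Attribute entries land in staticOffsets; Value entries
// get a ShapedType::kDynamic placeholder there and are appended, in order,
// to dynamicOffsets.
static void splitMixedOffsets(ArrayRef<OpFoldResult> mixedOffsets,
                              SmallVectorImpl<Value> &dynamicOffsets,
                              SmallVectorImpl<int64_t> &staticOffsets) {
  dispatchIndexOpFoldResults(mixedOffsets, dynamicOffsets, staticOffsets);
}
```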
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:24
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
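A sketch combining getAsOpFoldResult with the getConstantIntValue helper listed above; isKnownToEqual is a hypothetical name.
```cpp
// Hypothetical helper: true iff `v` is statically known to equal `expected`.
static bool isKnownToEqual(Value v, int64_t expected) {
  // If `v` is produced by a constant, this yields an Attribute; otherwise the
  // Value itself is returned and getConstantIntValue yields std::nullopt.
  OpFoldResult ofr = getAsOpFoldResult(v);
  return getConstantIntValue(ofr) == expected;
}
```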
function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn
Callback function type for setting the strided metadata of a value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
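A worked sketch: reducing {1, 4, 1, 8} to {4, 8} drops the unit dimensions at positions 0 and 2; rankReductionExample is a hypothetical function written only to show the expected mask.
```cpp
// Hypothetical example: the returned mask identifies which dims were dropped.
static void rankReductionExample() {
  std::optional<llvm::SmallDenseSet<unsigned>> mask =
      computeRankReductionMask(/*originalShape=*/{1, 4, 1, 8},
                               /*reducedShape=*/{4, 8});
  // mask is expected to contain {0, 2}; std::nullopt would mean the reduced
  // shape cannot be obtained by dropping unit dimensions only.
  (void)mask;
}
```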
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
function_ref< IntegerValueRange(Value)> GetIntRangeFn
Helper callback type to get the integer range of a value.
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.