MLIR 23.0.0git
MemRefOps.cpp
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/AffineMap.h"
15#include "mlir/IR/Builders.h"
17#include "mlir/IR/Matchers.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
27
28using namespace mlir;
29using namespace mlir::memref;
30
31/// Materialize a single constant operation from a given attribute value with
32/// the desired resultant type.
33Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
34 Attribute value, Type type,
35 Location loc) {
36 return arith::ConstantOp::materialize(builder, value, type, loc);
37}
38
39//===----------------------------------------------------------------------===//
40// Common canonicalization pattern support logic
41//===----------------------------------------------------------------------===//
42
43/// This is a common class used for patterns of the form
44/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
45/// into the root operation directly.
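///
/// Illustrative example (added here for exposition, not part of the upstream
/// file); for instance, `DeallocOp::fold` below uses this helper to rewrite
/// ```mlir
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// ```
/// into
/// ```mlir
///   memref.dealloc %arg : memref<8xf32>
/// ```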
46LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
47 bool folded = false;
48 for (OpOperand &operand : op->getOpOperands()) {
49 auto cast = operand.get().getDefiningOp<CastOp>();
50 if (cast && operand.get() != inner &&
51 !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
52 operand.set(cast.getOperand());
53 folded = true;
54 }
55 }
56 return success(folded);
57}
58
59/// Return an unranked/ranked tensor type for the given unranked/ranked memref
60/// type.
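///
/// For example (illustrative, not from the upstream file): `memref<4x?xf32>`
/// maps to `tensor<4x?xf32>`, the unranked `memref<*xf32>` maps to
/// `tensor<*xf32>`, and any non-memref type maps to `none`.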
61Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
62 if (auto memref = llvm::dyn_cast<MemRefType>(type))
63 return RankedTensorType::get(memref.getShape(), memref.getElementType());
64 if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
65 return UnrankedTensorType::get(memref.getElementType());
66 return NoneType::get(type.getContext());
67}
68
69OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
70 int64_t dim) {
71 auto memrefType = llvm::cast<MemRefType>(value.getType());
72 if (memrefType.isDynamicDim(dim))
73 return builder.createOrFold<memref::DimOp>(loc, value, dim);
74
75 return builder.getIndexAttr(memrefType.getDimSize(dim));
76}
77
78SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
79 Location loc, Value value) {
80 auto memrefType = llvm::cast<MemRefType>(value.getType());
81 SmallVector<OpFoldResult> result;
82 for (int64_t i = 0; i < memrefType.getRank(); ++i)
83 result.push_back(getMixedSize(builder, loc, value, i));
84 return result;
85}
86
87//===----------------------------------------------------------------------===//
88// Utility functions for propagating static information
89//===----------------------------------------------------------------------===//
90
91/// Helper function that sets values[i] to constValues[i] if the latter is a
92/// static value (i.e., not equal to ShapedType::kDynamic).
93///
94/// If constValues[i] is dynamic, tries to extract a constant value from
95/// values[i] to allow for additional folding opportunities. Also converts all
96/// existing attributes to index attributes. (They may be i64 attributes.)
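///
/// Worked example (illustrative, not from the upstream file): with
/// values = [%d, 4 : i64] and constValues = [10, ShapedType::kDynamic], the
/// result is values = [10 : index, 4 : index]; the first entry is replaced by
/// the known static size, and the second is converted from an i64 attribute
/// to an index attribute.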
97static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
98 ArrayRef<int64_t> constValues) {
99 assert(constValues.size() == values.size() &&
100 "incorrect number of const values");
101 for (auto [i, cstVal] : llvm::enumerate(constValues)) {
102 Builder builder(values[i].getContext());
103 if (ShapedType::isStatic(cstVal)) {
104 // Constant value is known, use it directly.
105 values[i] = builder.getIndexAttr(cstVal);
106 continue;
107 }
108 if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
109 // Try to extract a constant or convert an existing attribute to index.
110 values[i] = builder.getIndexAttr(*cst);
111 }
112 }
113}
114
115/// Helper function to retrieve a lossless memory-space cast, and the
116/// corresponding new result memref type.
117static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
118getMemorySpaceCastInfo(PtrLikeTypeInterface resultTy, Value src) {
119 MemorySpaceCastOpInterface castOp =
120 MemorySpaceCastOpInterface::getIfPromotableCast(src);
121
122 // Bail if the cast is not lossless.
123 if (!castOp)
124 return {};
125
126 // Transform the source and target type of `castOp` to have the same metadata
127 // as `resultTy`. Bail if not possible.
128 FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
129 castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
130 if (failed(srcTy))
131 return {};
132
133 FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
134 castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
135 if (failed(tgtTy))
136 return {};
137
138 // Check if this is a valid memory-space cast.
139 if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
140 return {};
141
142 return std::make_tuple(castOp, *tgtTy, *srcTy);
143}
144
145/// Implementation of `bubbleDownCasts` method for memref operations that
146/// return a single memref result.
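///
/// Rough sketch of the effect (illustrative, not from the upstream file, and
/// assuming the current `memref.memory_space_cast` / `memref.assume_alignment`
/// assembly syntax): for a pass-through op consuming the result of a lossless
/// memory-space cast,
/// ```mlir
///   %c = memref.memory_space_cast %src : memref<8xf32, 1> to memref<8xf32>
///   %a = memref.assume_alignment %c, 16 : memref<8xf32>
/// ```
/// the op is recreated directly on %src and a new memory-space cast is
/// emitted for its result, so the cast "bubbles down" past the op.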
147template <typename ConcreteOpTy>
148static FailureOr<std::optional<SmallVector<Value>>>
149bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
150 OpOperand &src) {
151 auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
152 // Bail if we cannot cast.
153 if (!castOp)
154 return failure();
155
156 // Create the new operands.
157 SmallVector<Value> operands;
158 llvm::append_range(operands, op->getOperands());
159 operands[src.getOperandNumber()] = castOp.getSourcePtr();
160
161 // Create the new op and results.
162 auto newOp = ConcreteOpTy::create(
163 builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
164 llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
165
166 // Insert a memory-space cast to the original memory space of the op.
167 MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
168 builder, tgtTy,
169 cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
170 return std::optional<SmallVector<Value>>(
171 SmallVector<Value>({result.getTargetPtr()}));
172}
173
174//===----------------------------------------------------------------------===//
175// AllocOp / AllocaOp
176//===----------------------------------------------------------------------===//
177
178void AllocOp::getAsmResultNames(
179 function_ref<void(Value, StringRef)> setNameFn) {
180 setNameFn(getResult(), "alloc");
181}
182
183void AllocaOp::getAsmResultNames(
184 function_ref<void(Value, StringRef)> setNameFn) {
185 setNameFn(getResult(), "alloca");
186}
187
188template <typename AllocLikeOp>
189static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
190 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
191 "applies to only alloc or alloca");
192 auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
193 if (!memRefType)
194 return op.emitOpError("result must be a memref");
195
196 if (failed(verifyDynamicDimensionCount(op, memRefType, op.getDynamicSizes())))
197 return failure();
198
199 unsigned numSymbols = 0;
200 if (!memRefType.getLayout().isIdentity())
201 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
202 if (op.getSymbolOperands().size() != numSymbols)
203 return op.emitOpError("symbol operand count does not equal memref symbol "
204 "count: expected ")
205 << numSymbols << ", got " << op.getSymbolOperands().size();
206
207 return success();
208}
209
210LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
211
212LogicalResult AllocaOp::verify() {
213 // An alloca op needs to have an ancestor with an allocation scope trait.
214 if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
215 return emitOpError(
216 "requires an ancestor op with AutomaticAllocationScope trait");
217
218 return verifyAllocLikeOp(*this);
219}
220
221namespace {
222/// Fold constant dimensions into an alloc like operation.
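///
/// Illustrative example (not from the upstream file):
/// ```mlir
///   %c16 = arith.constant 16 : index
///   %a = memref.alloc(%c16) : memref<?xf32>
/// ```
/// is rewritten to
/// ```mlir
///   %0 = memref.alloc() : memref<16xf32>
///   %a = memref.cast %0 : memref<16xf32> to memref<?xf32>
/// ```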
223template <typename AllocLikeOp>
224struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
225 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
226
227 LogicalResult matchAndRewrite(AllocLikeOp alloc,
228 PatternRewriter &rewriter) const override {
229 // Check to see if any dimension operands are constants. If so, we can
230 // substitute and drop them.
231 if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
232 APInt constSizeArg;
233 if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
234 return false;
235 return constSizeArg.isNonNegative();
236 }))
237 return failure();
238
239 auto memrefType = alloc.getType();
240
241 // Ok, we have one or more constant operands. Collect the non-constant ones
242 // and keep track of the resultant memref type to build.
243 SmallVector<int64_t, 4> newShapeConstants;
244 newShapeConstants.reserve(memrefType.getRank());
245 SmallVector<Value, 4> dynamicSizes;
246
247 unsigned dynamicDimPos = 0;
248 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
249 int64_t dimSize = memrefType.getDimSize(dim);
250 // If this is already a static dimension, keep it.
251 if (ShapedType::isStatic(dimSize)) {
252 newShapeConstants.push_back(dimSize);
253 continue;
254 }
255 auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
256 APInt constSizeArg;
257 if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
258 constSizeArg.isNonNegative()) {
259 // Dynamic shape dimension will be folded.
260 newShapeConstants.push_back(constSizeArg.getZExtValue());
261 } else {
262 // Dynamic shape dimension not folded; copy dynamicSize from old memref.
263 newShapeConstants.push_back(ShapedType::kDynamic);
264 dynamicSizes.push_back(dynamicSize);
265 }
266 dynamicDimPos++;
267 }
268
269 // Create new memref type (which will have fewer dynamic dimensions).
270 MemRefType newMemRefType =
271 MemRefType::Builder(memrefType).setShape(newShapeConstants);
272 assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
273
274 // Create and insert the alloc op for the new memref.
275 auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
276 dynamicSizes, alloc.getSymbolOperands(),
277 alloc.getAlignmentAttr());
278 // Insert a cast so we have the same type as the old alloc.
279 rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
280 return success();
281 }
282};
283
284/// Fold alloc operations with no users or only store and dealloc uses.
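///
/// Illustrative example (not from the upstream file): an allocation that is
/// only written to and then deallocated,
/// ```mlir
///   %a = memref.alloc() : memref<4xf32>
///   memref.store %v, %a[%i] : memref<4xf32>
///   memref.dealloc %a : memref<4xf32>
/// ```
/// is dead, and all three operations are erased.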
285template <typename T>
286struct SimplifyDeadAlloc : public OpRewritePattern<T> {
287 using OpRewritePattern<T>::OpRewritePattern;
288
289 LogicalResult matchAndRewrite(T alloc,
290 PatternRewriter &rewriter) const override {
291 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
292 if (auto storeOp = dyn_cast<StoreOp>(op))
293 return storeOp.getValue() == alloc;
294 return !isa<DeallocOp>(op);
295 }))
296 return failure();
297
298 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
299 rewriter.eraseOp(user);
300
301 rewriter.eraseOp(alloc);
302 return success();
303 }
304};
305} // namespace
306
307void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
308 MLIRContext *context) {
309 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
310}
311
312void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
313 MLIRContext *context) {
314 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
315 context);
316}
317
318//===----------------------------------------------------------------------===//
319// ReallocOp
320//===----------------------------------------------------------------------===//
321
322LogicalResult ReallocOp::verify() {
323 auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
324 MemRefType resultType = getType();
325
326 // The source memref should have identity layout (or none).
327 if (!sourceType.getLayout().isIdentity())
328 return emitError("unsupported layout for source memref type ")
329 << sourceType;
330
331 // The result memref should have identity layout (or none).
332 if (!resultType.getLayout().isIdentity())
333 return emitError("unsupported layout for result memref type ")
334 << resultType;
335
336 // The source memref and the result memref should be in the same memory space.
337 if (sourceType.getMemorySpace() != resultType.getMemorySpace())
338 return emitError("different memory spaces specified for source memref "
339 "type ")
340 << sourceType << " and result memref type " << resultType;
341
342 // The source memref and the result memref should have the same element type.
343 if (sourceType.getElementType() != resultType.getElementType())
344 return emitError("different element types specified for source memref "
345 "type ")
346 << sourceType << " and result memref type " << resultType;
347
348 // Verify that we have the dynamic dimension operand when it is needed.
349 if (resultType.getNumDynamicDims() && !getDynamicResultSize())
350 return emitError("missing dimension operand for result type ")
351 << resultType;
352 if (!resultType.getNumDynamicDims() && getDynamicResultSize())
353 return emitError("unnecessary dimension operand for result type ")
354 << resultType;
355
356 return success();
357}
358
359void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
360 MLIRContext *context) {
361 results.add<SimplifyDeadAlloc<ReallocOp>>(context);
362}
363
364//===----------------------------------------------------------------------===//
365// AllocaScopeOp
366//===----------------------------------------------------------------------===//
367
368void AllocaScopeOp::print(OpAsmPrinter &p) {
369 bool printBlockTerminators = false;
370
371 p << ' ';
372 if (!getResults().empty()) {
373 p << " -> (" << getResultTypes() << ")";
374 printBlockTerminators = true;
375 }
376 p << ' ';
377 p.printRegion(getBodyRegion(),
378 /*printEntryBlockArgs=*/false,
379 /*printBlockTerminators=*/printBlockTerminators);
380 p.printOptionalAttrDict((*this)->getAttrs());
381}
382
383ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
384 // Create a region for the body.
385 result.regions.reserve(1);
386 Region *bodyRegion = result.addRegion();
387
388 // Parse optional results type list.
389 if (parser.parseOptionalArrowTypeList(result.types))
390 return failure();
391
392 // Parse the body region.
393 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
394 return failure();
395 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
396 result.location);
397
398 // Parse the optional attribute list.
399 if (parser.parseOptionalAttrDict(result.attributes))
400 return failure();
401
402 return success();
403}
404
405void AllocaScopeOp::getSuccessorRegions(
406 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
407 if (!point.isParent()) {
408 regions.push_back(RegionSuccessor::parent());
409 return;
410 }
411
412 regions.push_back(RegionSuccessor(&getBodyRegion()));
413}
414
415ValueRange AllocaScopeOp::getSuccessorInputs(RegionSuccessor successor) {
416 return successor.isParent() ? ValueRange(getResults()) : ValueRange();
417}
418
419/// Given an operation, return whether this op is guaranteed to
420/// allocate an AutomaticAllocationScopeResource
421static bool isGuaranteedAutomaticAllocation(Operation *op) {
422 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
423 if (!interface)
424 return false;
425 for (auto res : op->getResults()) {
426 if (auto effect =
427 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
428 if (isa<SideEffects::AutomaticAllocationScopeResource>(
429 effect->getResource()))
430 return true;
431 }
432 }
433 return false;
434}
435
436/// Given an operation, return whether this op itself could
437/// allocate an AutomaticAllocationScopeResource. Note that
438/// this will not check whether an operation contained within
439/// the op can allocate.
440static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
441 // This op itself doesn't create a stack allocation,
442 // the inner allocation should be handled separately.
443 if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
444 return false;
445 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
446 if (!interface)
447 return true;
448 for (auto res : op->getResults()) {
449 if (auto effect =
450 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
451 if (isa<SideEffects::AutomaticAllocationScopeResource>(
452 effect->getResource()))
453 return true;
454 }
455 }
456 return false;
457}
458
459/// Return whether this op is the last non-terminating op
460/// in a region. That is to say, it is in a one-block region
461/// and is only followed by a terminator. This prevents
462/// extending the lifetime of allocations.
463static bool lastNonTerminatorInRegion(Operation *op) {
464 return op->getBlock()->mightHaveTerminator() &&
465 op->getNextNode() == op->getBlock()->getTerminator() &&
466 op->getParentRegion()->hasOneBlock();
467}
468
469/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
470/// or it contains no allocation.
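///
/// Illustrative example (not from the upstream file): a scope whose body
/// performs no stack allocation,
/// ```mlir
///   memref.alloca_scope {
///     memref.store %v, %m[%i] : memref<4xf32>
///   }
/// ```
/// is simply replaced by its body.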
471struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
472 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
473
474 LogicalResult matchAndRewrite(AllocaScopeOp op,
475 PatternRewriter &rewriter) const override {
476 bool hasPotentialAlloca =
477 op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
478 if (alloc == op)
479 return WalkResult::advance();
480 if (isGuaranteedAutomaticAllocation(alloc))
481 return WalkResult::interrupt();
482 if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
483 return WalkResult::skip();
484 return WalkResult::advance();
485 }).wasInterrupted();
486
487 // If this contains no potential allocation, it is always legal to
488 // inline. Otherwise, consider two conditions:
489 if (hasPotentialAlloca) {
490 // If the parent isn't an allocation scope, or we are not the last
491 // non-terminator op in the parent, we will extend the lifetime.
492 if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
493 return failure();
494 if (!lastNonTerminatorInRegion(op))
495 return failure();
496 }
497
498 Block *block = &op.getRegion().front();
499 Operation *terminator = block->getTerminator();
500 ValueRange results = terminator->getOperands();
501 rewriter.inlineBlockBefore(block, op);
502 rewriter.replaceOp(op, results);
503 rewriter.eraseOp(terminator);
504 return success();
505 }
506};
507
508/// Move allocations into an allocation scope, if it is legal to
509/// move them (e.g. their operands are available at the location
510/// the op would be moved to).
511struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
512 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
513
514 LogicalResult matchAndRewrite(AllocaScopeOp op,
515 PatternRewriter &rewriter) const override {
516
517 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
518 return failure();
519
520 Operation *lastParentWithoutScope = op->getParentOp();
521
522 if (!lastParentWithoutScope ||
523 lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
524 return failure();
525
526 // Only apply if this is the last non-terminator op in the block of a
527 // one-block region (lest lifetime be extended).
529 if (!lastNonTerminatorInRegion(op) ||
530 !lastNonTerminatorInRegion(lastParentWithoutScope))
531 return failure();
532
533 while (!lastParentWithoutScope->getParentOp()
534 ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
535 lastParentWithoutScope = lastParentWithoutScope->getParentOp();
536 if (!lastParentWithoutScope ||
537 !lastNonTerminatorInRegion(lastParentWithoutScope))
538 return failure();
539 }
540 assert(lastParentWithoutScope->getParentOp()
541 ->hasTrait<OpTrait::AutomaticAllocationScope>());
542
543 Region *containingRegion = nullptr;
544 for (auto &r : lastParentWithoutScope->getRegions()) {
545 if (r.isAncestor(op->getParentRegion())) {
546 assert(containingRegion == nullptr &&
547 "only one region can contain the op");
548 containingRegion = &r;
549 }
550 }
551 assert(containingRegion && "op must be contained in a region");
552
552
553 SmallVector<Operation *> toHoist;
554 op->walk([&](Operation *alloc) {
555 if (!isOpItselfPotentialAutomaticAllocation(alloc))
556 return WalkResult::skip();
557
558 // If any operand is not defined before the location of
559 // lastParentWithoutScope (i.e. where we would hoist to), skip.
560 if (llvm::any_of(alloc->getOperands(), [&](Value v) {
561 return containingRegion->isAncestor(v.getParentRegion());
562 }))
563 return WalkResult::skip();
564 toHoist.push_back(alloc);
565 return WalkResult::advance();
566 });
567
568 if (toHoist.empty())
569 return failure();
570 rewriter.setInsertionPoint(lastParentWithoutScope);
571 for (auto *op : toHoist) {
572 auto *cloned = rewriter.clone(*op);
573 rewriter.replaceOp(op, cloned->getResults());
574 }
575 return success();
576 }
577};
578
579void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
580 MLIRContext *context) {
581 results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
582}
583
584//===----------------------------------------------------------------------===//
585// AssumeAlignmentOp
586//===----------------------------------------------------------------------===//
587
588LogicalResult AssumeAlignmentOp::verify() {
589 if (!llvm::isPowerOf2_32(getAlignment()))
590 return emitOpError("alignment must be power of 2");
591 return success();
592}
593
594void AssumeAlignmentOp::getAsmResultNames(
595 function_ref<void(Value, StringRef)> setNameFn) {
596 setNameFn(getResult(), "assume_align");
597}
598
599OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
600 auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
601 if (!source)
602 return {};
603 if (source.getAlignment() != getAlignment())
604 return {};
605 return getMemref();
606}
607
608FailureOr<std::optional<SmallVector<Value>>>
609AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
610 return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
611}
612
613FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
614 int resultIndex,
615 int dim) {
616 assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
617 return getMixedSize(builder, getLoc(), getMemref(), dim);
618}
619
620//===----------------------------------------------------------------------===//
621// DistinctObjectsOp
622//===----------------------------------------------------------------------===//
623
624LogicalResult DistinctObjectsOp::verify() {
625 if (getOperandTypes() != getResultTypes())
626 return emitOpError("operand types and result types must match");
627
628 if (getOperandTypes().empty())
629 return emitOpError("expected at least one operand");
630
631 return success();
632}
633
634LogicalResult DistinctObjectsOp::inferReturnTypes(
635 MLIRContext * /*context*/, std::optional<Location> /*location*/,
636 ValueRange operands, DictionaryAttr /*attributes*/,
637 OpaqueProperties /*properties*/, RegionRange /*regions*/,
638 SmallVectorImpl<Type> &inferredReturnTypes) {
639 llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
640 return success();
641}
642
643//===----------------------------------------------------------------------===//
644// CastOp
645//===----------------------------------------------------------------------===//
646
647void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
648 setNameFn(getResult(), "cast");
649}
650
651/// Determines whether MemRef_CastOp casts to a more dynamic version of the
652/// source memref. This is useful to fold a memref.cast into a consuming op
653/// and implement canonicalization patterns for ops in different dialects that
654/// may consume the results of memref.cast operations. Such foldable memref.cast
655/// operations are typically inserted as `view` and `subview` ops are
656/// canonicalized, to preserve the type compatibility of their uses.
657///
658/// Returns true when all conditions are met:
659/// 1. source and result are ranked memrefs with strided semantics and same
660/// element type and rank.
661/// 2. each of the source's size, offset or stride has more static information
662/// than the corresponding result's size, offset or stride.
663///
664/// Example 1:
665/// ```mlir
666/// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
667/// %2 = consumer %1 ... : memref<?x?xf32> ...
668/// ```
669///
670/// may fold into:
671///
672/// ```mlir
673/// %2 = consumer %0 ... : memref<8x16xf32> ...
674/// ```
675///
676/// Example 2:
677/// ```
678/// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
679/// to memref<?x?xf32>
680/// consumer %1 : memref<?x?xf32> ...
681/// ```
682///
683/// may fold into:
684///
685/// ```
686/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
687/// ```
688bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
689 MemRefType sourceType =
690 llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
691 MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
692
693 // Requires ranked MemRefType.
694 if (!sourceType || !resultType)
695 return false;
696
697 // Requires same elemental type.
698 if (sourceType.getElementType() != resultType.getElementType())
699 return false;
700
701 // Requires same rank.
702 if (sourceType.getRank() != resultType.getRank())
703 return false;
704
705 // Only fold casts between strided memref forms.
706 int64_t sourceOffset, resultOffset;
707 SmallVector<int64_t, 4> sourceStrides, resultStrides;
708 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
709 failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
710 return false;
711
712 // If cast is towards more static sizes along any dimension, don't fold.
713 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
714 auto ss = std::get<0>(it), st = std::get<1>(it);
715 if (ss != st)
716 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
717 return false;
718 }
719
720 // If cast is towards more static offset along any dimension, don't fold.
721 if (sourceOffset != resultOffset)
722 if (ShapedType::isDynamic(sourceOffset) &&
723 ShapedType::isStatic(resultOffset))
724 return false;
725
726 // If cast is towards more static strides along any dimension, don't fold.
727 for (auto it : llvm::zip(sourceStrides, resultStrides)) {
728 auto ss = std::get<0>(it), st = std::get<1>(it);
729 if (ss != st)
730 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
731 return false;
732 }
733
734 return true;
735}
736
737bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
738 if (inputs.size() != 1 || outputs.size() != 1)
739 return false;
740 Type a = inputs.front(), b = outputs.front();
741 auto aT = llvm::dyn_cast<MemRefType>(a);
742 auto bT = llvm::dyn_cast<MemRefType>(b);
743
744 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
745 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
746
747 if (aT && bT) {
748 if (aT.getElementType() != bT.getElementType())
749 return false;
750 if (aT.getLayout() != bT.getLayout()) {
751 int64_t aOffset, bOffset;
752 SmallVector<int64_t, 4> aStrides, bStrides;
753 if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
754 failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
755 aStrides.size() != bStrides.size())
756 return false;
757
758 // Strides along a dimension/offset are compatible if the value in the
759 // source memref is static and the value in the target memref is the
760 // same. They are also compatible if either one is dynamic (see
761 // description of MemRefCastOp for details).
762 // Note that for dimensions of size 1, the stride can differ.
763 auto checkCompatible = [](int64_t a, int64_t b) {
764 return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
765 };
766 if (!checkCompatible(aOffset, bOffset))
767 return false;
768 for (const auto &[index, aStride] : enumerate(aStrides)) {
769 if (aT.getDimSize(index) == 1)
770 continue;
771 if (!checkCompatible(aStride, bStrides[index]))
772 return false;
773 }
774 }
775 if (aT.getMemorySpace() != bT.getMemorySpace())
776 return false;
777
778 // They must have the same rank, and any specified dimensions must match.
779 if (aT.getRank() != bT.getRank())
780 return false;
781
782 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
783 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
784 if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
785 aDim != bDim)
786 return false;
787 }
788 return true;
789 } else {
790 if (!aT && !uaT)
791 return false;
792 if (!bT && !ubT)
793 return false;
794 // Unranked to unranked casting is unsupported
795 if (uaT && ubT)
796 return false;
797
798 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
799 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
800 if (aEltType != bEltType)
801 return false;
802
803 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
804 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
805 return aMemSpace == bMemSpace;
806 }
807
808 return false;
809}
810
811OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
812 return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
813}
814
815FailureOr<std::optional<SmallVector<Value>>>
816CastOp::bubbleDownCasts(OpBuilder &builder) {
817 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
818}
819
820//===----------------------------------------------------------------------===//
821// CopyOp
822//===----------------------------------------------------------------------===//
823
824namespace {
825
826/// Fold memref.copy(%x, %x).
827struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
828 using OpRewritePattern<CopyOp>::OpRewritePattern;
829
830 LogicalResult matchAndRewrite(CopyOp copyOp,
831 PatternRewriter &rewriter) const override {
832 if (copyOp.getSource() != copyOp.getTarget())
833 return failure();
834
835 rewriter.eraseOp(copyOp);
836 return success();
837 }
838};
839
840struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
841 using OpRewritePattern<CopyOp>::OpRewritePattern;
842
843 static bool isEmptyMemRef(BaseMemRefType type) {
844 return type.hasRank() && llvm::is_contained(type.getShape(), 0);
845 }
846
847 LogicalResult matchAndRewrite(CopyOp copyOp,
848 PatternRewriter &rewriter) const override {
849 if (isEmptyMemRef(copyOp.getSource().getType()) ||
850 isEmptyMemRef(copyOp.getTarget().getType())) {
851 rewriter.eraseOp(copyOp);
852 return success();
853 }
854
855 return failure();
856 }
857};
858} // namespace
859
860void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
861 MLIRContext *context) {
862 results.add<FoldEmptyCopy, FoldSelfCopy>(context);
863}
864
865/// If the source/target of a CopyOp is a CastOp that does not modify the shape
866/// and element type, the cast can be skipped. Such CastOps only cast the layout
867/// of the type.
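///
/// Illustrative example (not from the upstream file):
/// ```mlir
///   %0 = memref.cast %src : memref<8xf32> to memref<?xf32>
///   memref.copy %0, %dst : memref<?xf32> to memref<?xf32>
/// ```
/// folds to
/// ```mlir
///   memref.copy %src, %dst : memref<8xf32> to memref<?xf32>
/// ```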
868static LogicalResult foldCopyOfCast(CopyOp op) {
869 for (OpOperand &operand : op->getOpOperands()) {
870 auto castOp = operand.get().getDefiningOp<memref::CastOp>();
871 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
872 operand.set(castOp.getOperand());
873 return success();
874 }
875 }
876 return failure();
877}
878
879LogicalResult CopyOp::fold(FoldAdaptor adaptor,
880 SmallVectorImpl<OpFoldResult> &results) {
881
882 /// copy(memrefcast) -> copy
883 return foldCopyOfCast(*this);
884}
885
886//===----------------------------------------------------------------------===//
887// DeallocOp
888//===----------------------------------------------------------------------===//
889
890LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
891 SmallVectorImpl<OpFoldResult> &results) {
892 /// dealloc(memrefcast) -> dealloc
893 return foldMemRefCast(*this);
894}
895
896//===----------------------------------------------------------------------===//
897// DimOp
898//===----------------------------------------------------------------------===//
899
900void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
901 setNameFn(getResult(), "dim");
902}
903
904void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
905 int64_t index) {
906 auto loc = result.location;
907 Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
908 build(builder, result, source, indexValue);
909}
910
911std::optional<int64_t> DimOp::getConstantIndex() {
912 return getConstantIntValue(getIndex());
913}
914
915Speculation::Speculatability DimOp::getSpeculatability() {
916 auto constantIndex = getConstantIndex();
917 if (!constantIndex)
918 return Speculation::NotSpeculatable;
919
920 auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
921 if (!rankedSourceType)
922 return Speculation::NotSpeculatable;
923
924 if (rankedSourceType.getRank() <= constantIndex)
925 return Speculation::NotSpeculatable;
926
927 return Speculation::Speculatable;
928}
929
930void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
931 SetIntLatticeFn setResultRange) {
932 setResultRange(getResult(),
933 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
934}
935
936/// Return a map with keys being the elements in `vals` and values being the
937/// number of occurrences of each. Use std::map, since the `vals` here are
938/// strides and the dynamic stride value is the same as the tombstone value
939/// for `DenseMap<int64_t>`.
940static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
941 std::map<int64_t, unsigned> numOccurences;
942 for (auto val : vals)
943 numOccurences[val]++;
944 return numOccurences;
945}
946
947/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
948/// to be a subset of `originalType` with some `1` entries erased, return the
949/// set of indices that specifies which of the entries of `originalShape` are
950/// dropped to obtain `reducedShape`.
951/// This accounts for cases where there are multiple unit-dims, but only a
952/// subset of those are dropped. For MemRefTypes these can be disambiguated
953/// using the strides. If a dimension is dropped the stride must be dropped too.
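///
/// Worked example (illustrative, not from the upstream file): for an original
/// type `memref<1x4x1xf32>` (strides [4, 1, 1]), a reduced type
/// `memref<1x4xf32, strided<[4, 1]>>` and sizes [1, 4, 1], both unit dims are
/// candidates, but the stride 4 of dim 0 is still present in the reduced type
/// while one occurrence of stride 1 has disappeared, so only dim 2 is
/// reported as dropped.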
954static FailureOr<llvm::SmallBitVector>
955computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
956 ArrayRef<OpFoldResult> sizes) {
957 llvm::SmallBitVector unusedDims(originalType.getRank());
958 if (originalType.getRank() == reducedType.getRank())
959 return unusedDims;
960
961 for (const auto &dim : llvm::enumerate(sizes))
962 if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
963 if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
964 unusedDims.set(dim.index());
965
966 // Early exit for the case where the number of unused dims matches the number
967 // of ranks reduced.
968 if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
969 originalType.getRank())
970 return unusedDims;
971
972 SmallVector<int64_t> originalStrides, candidateStrides;
973 int64_t originalOffset, candidateOffset;
974 if (failed(
975 originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
976 failed(
977 reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
978 return failure();
979
980 // For memrefs, a dimension is truly dropped if its corresponding stride is
981 // also dropped. This is particularly important when more than one of the dims
982 // is 1. Track the number of occurrences of the strides in the original type
983 // and the candidate type. For each unused dim that stride should not be
984 // present in the candidate type. Note that there could be multiple dimensions
985 // that have the same size. We don't need to exactly figure out which dim
986 // corresponds to which stride, we just need to verify that the number of
987 // repetitions of a stride in the original + number of unused dims with that
988 // stride == number of repetitions of a stride in the candidate.
989 std::map<int64_t, unsigned> currUnaccountedStrides =
990 getNumOccurences(originalStrides);
991 std::map<int64_t, unsigned> candidateStridesNumOccurences =
992 getNumOccurences(candidateStrides);
993 for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
994 if (!unusedDims.test(dim))
995 continue;
996 int64_t originalStride = originalStrides[dim];
997 if (currUnaccountedStrides[originalStride] >
998 candidateStridesNumOccurences[originalStride]) {
999 // This dim can be treated as dropped.
1000 currUnaccountedStrides[originalStride]--;
1001 continue;
1002 }
1003 if (currUnaccountedStrides[originalStride] ==
1004 candidateStridesNumOccurences[originalStride]) {
1005 // The stride for this is not dropped. Keep as is.
1006 unusedDims.reset(dim);
1007 continue;
1008 }
1009 if (currUnaccountedStrides[originalStride] <
1010 candidateStridesNumOccurences[originalStride]) {
1011 // This should never happen. Can't have a stride in the reduced rank type
1012 // that wasn't in the original one.
1013 return failure();
1014 }
1015 }
1016
1017 if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1018 originalType.getRank())
1019 return failure();
1020 return unusedDims;
1021}
1022
1023llvm::SmallBitVector SubViewOp::getDroppedDims() {
1024 MemRefType sourceType = getSourceType();
1025 MemRefType resultType = getType();
1026 FailureOr<llvm::SmallBitVector> unusedDims =
1027 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1028 assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1029 return *unusedDims;
1030}
1031
1032OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1033 // All forms of folding require a known index.
1034 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1035 if (!index)
1036 return {};
1037
1038 // Folding for unranked types (UnrankedMemRefType) is not supported.
1039 auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1040 if (!memrefType)
1041 return {};
1042
1043 // Out of bound indices produce undefined behavior but are still valid IR.
1044 // Don't choke on them.
1045 int64_t indexVal = index.getInt();
1046 if (indexVal < 0 || indexVal >= memrefType.getRank())
1047 return {};
1048
1049 // Fold if the shape extent along the given index is known.
1050 if (!memrefType.isDynamicDim(index.getInt())) {
1051 Builder builder(getContext());
1052 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1053 }
1054
1055 // The size at the given index is now known to be a dynamic size.
1056 unsigned unsignedIndex = index.getValue().getZExtValue();
1057
1058 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1059 Operation *definingOp = getSource().getDefiningOp();
1060
1061 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1062 return *(alloc.getDynamicSizes().begin() +
1063 memrefType.getDynamicDimIndex(unsignedIndex));
1064
1065 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1066 return *(alloca.getDynamicSizes().begin() +
1067 memrefType.getDynamicDimIndex(unsignedIndex));
1068
1069 if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1070 return *(view.getDynamicSizes().begin() +
1071 memrefType.getDynamicDimIndex(unsignedIndex));
1072
1073 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1074 llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1075 unsigned resultIndex = 0;
1076 unsigned sourceRank = subview.getSourceType().getRank();
1077 unsigned sourceIndex = 0;
1078 for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1079 if (unusedDims.test(i))
1080 continue;
1081 if (resultIndex == unsignedIndex) {
1082 sourceIndex = i;
1083 break;
1084 }
1085 resultIndex++;
1086 }
1087 assert(subview.isDynamicSize(sourceIndex) &&
1088 "expected dynamic subview size");
1089 return subview.getDynamicSize(sourceIndex);
1090 }
1091
1092 // dim(memrefcast) -> dim
1093 if (succeeded(foldMemRefCast(*this)))
1094 return getResult();
1095
1096 return {};
1097}
1098
1099namespace {
1100/// Fold dim of a memref reshape operation to a load into the reshape's shape
1101/// operand.
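///
/// Illustrative example (not from the upstream file):
/// ```mlir
///   %r = memref.reshape %src(%shape)
///       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
/// becomes a load of the shape operand:
/// ```mlir
///   %d = memref.load %shape[%c1] : memref<2xindex>
/// ```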
1102struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1103 using OpRewritePattern<DimOp>::OpRewritePattern;
1104
1105 LogicalResult matchAndRewrite(DimOp dim,
1106 PatternRewriter &rewriter) const override {
1107 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1108
1109 if (!reshape)
1110 return rewriter.notifyMatchFailure(
1111 dim, "Dim op is not defined by a reshape op.");
1112
1113 // dim of a memref reshape can be folded if dim.getIndex() dominates the
1114 // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1115 // cheaply check that either of the following conditions hold:
1116 // 1. dim.getIndex() is defined in the same block as reshape but before
1117 // reshape.
1118 // 2. dim.getIndex() is defined in a parent block of
1119 // reshape.
1120
1121 // Check condition 1
1122 if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1123 if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1124 if (reshape->isBeforeInBlock(definingOp)) {
1125 return rewriter.notifyMatchFailure(
1126 dim,
1127 "dim.getIndex is not defined before reshape in the same block.");
1128 }
1129 } // else dim.getIndex is a block argument to reshape->getBlock and
1130 // dominates reshape
1131 } // Check condition 2
1132 else if (dim->getBlock() != reshape->getBlock() &&
1133 !dim.getIndex().getParentRegion()->isProperAncestor(
1134 reshape->getParentRegion())) {
1135 // If dim and reshape are in the same block but dim.getIndex() isn't, we
1136 // already know dim.getIndex() dominates reshape without calling
1137 // `isProperAncestor`
1138 return rewriter.notifyMatchFailure(
1139 dim, "dim.getIndex does not dominate reshape.");
1140 }
1141
1142 // Place the load directly after the reshape to ensure that the shape memref
1143 // was not mutated.
1144 rewriter.setInsertionPointAfter(reshape);
1145 Location loc = dim.getLoc();
1146 Value load =
1147 LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1148 if (load.getType() != dim.getType())
1149 load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
1150 rewriter.replaceOp(dim, load);
1151 return success();
1152 }
1153};
1154
1155} // namespace
1156
1157void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1158 MLIRContext *context) {
1159 results.add<DimOfMemRefReshape>(context);
1160}
1161
1162// ---------------------------------------------------------------------------
1163// DmaStartOp
1164// ---------------------------------------------------------------------------
1165
1166void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1167 Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1168 ValueRange destIndices, Value numElements,
1169 Value tagMemRef, ValueRange tagIndices, Value stride,
1170 Value elementsPerStride) {
1171 result.addOperands(srcMemRef);
1172 result.addOperands(srcIndices);
1173 result.addOperands(destMemRef);
1174 result.addOperands(destIndices);
1175 result.addOperands({numElements, tagMemRef});
1176 result.addOperands(tagIndices);
1177 if (stride)
1178 result.addOperands({stride, elementsPerStride});
1179}
1180
1181void DmaStartOp::print(OpAsmPrinter &p) {
1182 p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1183 << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1184 << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1185 if (isStrided())
1186 p << ", " << getStride() << ", " << getNumElementsPerStride();
1187
1188 p.printOptionalAttrDict((*this)->getAttrs());
1189 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1190 << ", " << getTagMemRef().getType();
1191}
1192
1193// Parse DmaStartOp.
1194// Ex:
1195// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1196// %tag[%index], %stride, %num_elt_per_stride :
1197// : memref<3076 x f32, 0>,
1198// memref<1024 x f32, 2>,
1199// memref<1 x i32>
1200//
1201ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1202 OpAsmParser::UnresolvedOperand srcMemRefInfo;
1203 SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1204 OpAsmParser::UnresolvedOperand dstMemRefInfo;
1205 SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1206 OpAsmParser::UnresolvedOperand numElementsInfo;
1207 OpAsmParser::UnresolvedOperand tagMemrefInfo;
1208 SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1209 SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1210
1211 SmallVector<Type, 3> types;
1212 auto indexType = parser.getBuilder().getIndexType();
1213
1214 // Parse and resolve the following list of operands:
1215 // *) source memref followed by its indices (in square brackets).
1216 // *) destination memref followed by its indices (in square brackets).
1217 // *) number of elements to transfer (dma size).
1218 if (parser.parseOperand(srcMemRefInfo) ||
1219 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1220 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1221 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1222 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1223 parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1224 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1225 return failure();
1226
1227 // Parse optional stride and elements per stride.
1228 if (parser.parseTrailingOperandList(strideInfo))
1229 return failure();
1230
1231 bool isStrided = strideInfo.size() == 2;
1232 if (!strideInfo.empty() && !isStrided) {
1233 return parser.emitError(parser.getNameLoc(),
1234 "expected two stride related operands");
1235 }
1236
1237 if (parser.parseColonTypeList(types))
1238 return failure();
1239 if (types.size() != 3)
1240 return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1241
1242 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1243 parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1244 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1245 parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1246 // size should be an index.
1247 parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1248 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1249 // tag indices should be index.
1250 parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1251 return failure();
1252
1253 if (isStrided) {
1254 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1255 return failure();
1256 }
1257
1258 return success();
1259}
1260
1261LogicalResult DmaStartOp::verify() {
1262 unsigned numOperands = getNumOperands();
1263
1264 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1265 // the number of elements.
1266 if (numOperands < 4)
1267 return emitOpError("expected at least 4 operands");
1268
1269 // Check types of operands. The order of these calls is important: the later
1270 // calls rely on some type properties to compute the operand position.
1271 // 1. Source memref.
1272 if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1273 return emitOpError("expected source to be of memref type");
1274 if (numOperands < getSrcMemRefRank() + 4)
1275 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1276 << " operands";
1277 if (!getSrcIndices().empty() &&
1278 !llvm::all_of(getSrcIndices().getTypes(),
1279 [](Type t) { return t.isIndex(); }))
1280 return emitOpError("expected source indices to be of index type");
1281
1282 // 2. Destination memref.
1283 if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1284 return emitOpError("expected destination to be of memref type");
1285 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1286 if (numOperands < numExpectedOperands)
1287 return emitOpError() << "expected at least " << numExpectedOperands
1288 << " operands";
1289 if (!getDstIndices().empty() &&
1290 !llvm::all_of(getDstIndices().getTypes(),
1291 [](Type t) { return t.isIndex(); }))
1292 return emitOpError("expected destination indices to be of index type");
1293
1294 // 3. Number of elements.
1295 if (!getNumElements().getType().isIndex())
1296 return emitOpError("expected num elements to be of index type");
1297
1298 // 4. Tag memref.
1299 if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1300 return emitOpError("expected tag to be of memref type");
1301 numExpectedOperands += getTagMemRefRank();
1302 if (numOperands < numExpectedOperands)
1303 return emitOpError() << "expected at least " << numExpectedOperands
1304 << " operands";
1305 if (!getTagIndices().empty() &&
1306 !llvm::all_of(getTagIndices().getTypes(),
1307 [](Type t) { return t.isIndex(); }))
1308 return emitOpError("expected tag indices to be of index type");
1309
1310 // Optional stride-related operands must be either both present or both
1311 // absent.
1312 if (numOperands != numExpectedOperands &&
1313 numOperands != numExpectedOperands + 2)
1314 return emitOpError("incorrect number of operands");
1315
1316 // 5. Strides.
1317 if (isStrided()) {
1318 if (!getStride().getType().isIndex() ||
1319 !getNumElementsPerStride().getType().isIndex())
1320 return emitOpError(
1321 "expected stride and num elements per stride to be of type index");
1322 }
1323
1324 return success();
1325}
1326
1327LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1328 SmallVectorImpl<OpFoldResult> &results) {
1329 /// dma_start(memrefcast) -> dma_start
1330 return foldMemRefCast(*this);
1331}
1332
1333// ---------------------------------------------------------------------------
1334// DmaWaitOp
1335// ---------------------------------------------------------------------------
1336
1337LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1338 SmallVectorImpl<OpFoldResult> &results) {
1339 /// dma_wait(memrefcast) -> dma_wait
1340 return foldMemRefCast(*this);
1341}
1342
1343LogicalResult DmaWaitOp::verify() {
1344 // Check that the number of tag indices matches the tagMemRef rank.
1345 unsigned numTagIndices = getTagIndices().size();
1346 unsigned tagMemRefRank = getTagMemRefRank();
1347 if (numTagIndices != tagMemRefRank)
1348 return emitOpError() << "expected tagIndices to have the same number of "
1349 "elements as the tagMemRef rank, expected "
1350 << tagMemRefRank << ", but got " << numTagIndices;
1351 return success();
1352}
1353
1354//===----------------------------------------------------------------------===//
1355// ExtractAlignedPointerAsIndexOp
1356//===----------------------------------------------------------------------===//
1357
1358void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1359 function_ref<void(Value, StringRef)> setNameFn) {
1360 setNameFn(getResult(), "intptr");
1361}
1362
1363//===----------------------------------------------------------------------===//
1364// ExtractStridedMetadataOp
1365//===----------------------------------------------------------------------===//
1366
1367/// The number and type of the results are inferred from the
1368/// shape of the source.
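///
/// For example (illustrative, not from the upstream file): for a source of
/// type `memref<2x3xf32>`, the inferred results are a 0-d base buffer
/// `memref<f32>`, one `index` offset, two `index` sizes, and two `index`
/// strides.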
1369LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1370 MLIRContext *context, std::optional<Location> location,
1371 ExtractStridedMetadataOp::Adaptor adaptor,
1372 SmallVectorImpl<Type> &inferredReturnTypes) {
1373 auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1374 if (!sourceType)
1375 return failure();
1376
1377 unsigned sourceRank = sourceType.getRank();
1378 IndexType indexType = IndexType::get(context);
1379 auto memrefType =
1380 MemRefType::get({}, sourceType.getElementType(),
1381 MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1382 // Base.
1383 inferredReturnTypes.push_back(memrefType);
1384 // Offset.
1385 inferredReturnTypes.push_back(indexType);
1386 // Sizes and strides.
1387 for (unsigned i = 0; i < sourceRank * 2; ++i)
1388 inferredReturnTypes.push_back(indexType);
1389 return success();
1390}
1391
1392void ExtractStridedMetadataOp::getAsmResultNames(
1393 function_ref<void(Value, StringRef)> setNameFn) {
1394 setNameFn(getBaseBuffer(), "base_buffer");
1395 setNameFn(getOffset(), "offset");
1396 // For multi-result to work properly with pretty names and packed syntax `x:3`
1397 // we can only give a pretty name to the first value in the pack.
1398 if (!getSizes().empty()) {
1399 setNameFn(getSizes().front(), "sizes");
1400 setNameFn(getStrides().front(), "strides");
1401 }
1402}
1403
1404/// Helper function to perform the replacement of all constant uses of `values`
1405/// by a materialized constant extracted from `maybeConstants`.
1406/// `values` and `maybeConstants` are expected to have the same size.
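///
/// Illustrative example (not from the upstream file): if `values` holds the
/// size results of an `extract_strided_metadata` on a `memref<4x?xf32>` and
/// `maybeConstants` is [4, %dyn], uses of the first size result are rewritten
/// to use a materialized `arith.constant 4 : index`, while the second is left
/// untouched.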
1407template <typename Container>
1408static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1409 Container values,
1410 ArrayRef<OpFoldResult> maybeConstants) {
1411 assert(values.size() == maybeConstants.size() &&
1412 " expected values and maybeConstants of the same size");
1413 bool atLeastOneReplacement = false;
1414 for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1415 // Don't materialize a constant if there are no uses: this would induce
1416 // infinite loops in the driver.
1417 if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1418 continue;
1419 assert(isa<Attribute>(maybeConstant) &&
1420 "The constified value should be either unchanged (i.e., == result) "
1421 "or a constant");
1422 Value constantVal = arith::ConstantIndexOp::create(
1423 rewriter, loc,
1424 llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1425 for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1426 // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1427 // yet.
1428 op->replaceUsesOfWith(result, constantVal);
1429 atLeastOneReplacement = true;
1430 }
1431 }
1432 return atLeastOneReplacement;
1433}
1434
1435LogicalResult
1436ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1437 SmallVectorImpl<OpFoldResult> &results) {
1438 OpBuilder builder(*this);
1439
1440 bool atLeastOneReplacement = replaceConstantUsesOf(
1441 builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1442 getConstifiedMixedOffset());
1443 atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1444 getConstifiedMixedSizes());
1445 atLeastOneReplacement |= replaceConstantUsesOf(
1446 builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1447
1448 // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
1449 if (auto prev = getSource().getDefiningOp<CastOp>())
1450 if (isa<MemRefType>(prev.getSource().getType())) {
1451 getSourceMutable().assign(prev.getSource());
1452 atLeastOneReplacement = true;
1453 }
1454
1455 return success(atLeastOneReplacement);
1456}
1457
1458SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1459 SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1460 constifyIndexValues(values, getSource().getType().getShape());
1461 return values;
1462}
1463
1464SmallVector<OpFoldResult>
1465ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1466 SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1467 SmallVector<int64_t> staticValues;
1468 int64_t unused;
1469 LogicalResult status =
1470 getSource().getType().getStridesAndOffset(staticValues, unused);
1471 (void)status;
1472 assert(succeeded(status) && "could not get strides from type");
1473 constifyIndexValues(values, staticValues);
1474 return values;
1475}
1476
1477OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1478 OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1479 SmallVector<OpFoldResult> values(1, offsetOfr);
1480 SmallVector<int64_t> staticValues, unused;
1481 int64_t offset;
1482 LogicalResult status =
1483 getSource().getType().getStridesAndOffset(unused, offset);
1484 (void)status;
1485 assert(succeeded(status) && "could not get offset from type");
1486 staticValues.push_back(offset);
1487 constifyIndexValues(values, staticValues);
1488 return values[0];
1489}
1490
1491//===----------------------------------------------------------------------===//
1492// GenericAtomicRMWOp
1493//===----------------------------------------------------------------------===//
1494
1495void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1496 Value memref, ValueRange ivs) {
1497 OpBuilder::InsertionGuard g(builder);
1498 result.addOperands(memref);
1499 result.addOperands(ivs);
1500
1501 if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1502 Type elementType = memrefType.getElementType();
1503 result.addTypes(elementType);
1504
1505 Region *bodyRegion = result.addRegion();
1506 builder.createBlock(bodyRegion);
1507 bodyRegion->addArgument(elementType, memref.getLoc());
1508 }
1509}
1510
1511LogicalResult GenericAtomicRMWOp::verify() {
1512 auto &body = getRegion();
1513 if (body.getNumArguments() != 1)
1514 return emitOpError("expected single number of entry block arguments");
1515
1516 if (getResult().getType() != body.getArgument(0).getType())
1517 return emitOpError("expected block argument of the same type result type");
1518
1519 bool hasSideEffects =
1520 body.walk([&](Operation *nestedOp) {
1521 if (isMemoryEffectFree(nestedOp))
1522 return WalkResult::advance();
1523 nestedOp->emitError(
1524 "body of 'memref.generic_atomic_rmw' should contain "
1525 "only operations with no side effects");
1526 return WalkResult::interrupt();
1527 })
1528 .wasInterrupted();
1529 return hasSideEffects ? failure() : success();
1530}
1531
1532ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1533 OperationState &result) {
1534 OpAsmParser::UnresolvedOperand memref;
1535 Type memrefType;
1536 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1537
1538 Type indexType = parser.getBuilder().getIndexType();
1539 if (parser.parseOperand(memref) ||
1540 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1541 parser.parseColonType(memrefType) ||
1542 parser.resolveOperand(memref, memrefType, result.operands) ||
1543 parser.resolveOperands(ivs, indexType, result.operands))
1544 return failure();
1545
1546 Region *body = result.addRegion();
1547 if (parser.parseRegion(*body, {}) ||
1548 parser.parseOptionalAttrDict(result.attributes))
1549 return failure();
1550 result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1551 return success();
1552}
1553
1554void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1555 p << ' ' << getMemref() << "[" << getIndices()
1556 << "] : " << getMemref().getType() << ' ';
1557 p.printRegion(getRegion());
1558 p.printOptionalAttrDict((*this)->getAttrs());
1559}
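// For orientation, the custom assembly handled by the parse/print methods
// above looks roughly like the following (an illustrative sketch, not a
// verbatim test; %res, %buf and %i are placeholder names):
//
//   %res = memref.generic_atomic_rmw %buf[%i] : memref<10xf32> {
//   ^bb0(%current : f32):
//     %cst = arith.constant 1.0 : f32
//     %new = arith.addf %current, %cst : f32
//     memref.atomic_yield %new : f32
//   }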
1560
1561//===----------------------------------------------------------------------===//
1562// AtomicYieldOp
1563//===----------------------------------------------------------------------===//
1564
1565LogicalResult AtomicYieldOp::verify() {
1566 Type parentType = (*this)->getParentOp()->getResultTypes().front();
1567 Type resultType = getResult().getType();
1568 if (parentType != resultType)
1569 return emitOpError() << "types mismatch between yield op: " << resultType
1570 << " and its parent: " << parentType;
1571 return success();
1572}
1573
1574//===----------------------------------------------------------------------===//
1575// GlobalOp
1576//===----------------------------------------------------------------------===//
1577
1578static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1579 TypeAttr type,
1580 Attribute initialValue) {
1581 p << type;
1582 if (!op.isExternal()) {
1583 p << " = ";
1584 if (op.isUninitialized())
1585 p << "uninitialized";
1586 else
1587 p.printAttributeWithoutType(initialValue);
1588 }
1589}
1590
1591static ParseResult
1592parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1593 Attribute &initialValue) {
1594 Type type;
1595 if (parser.parseType(type))
1596 return failure();
1597
1598 auto memrefType = llvm::dyn_cast<MemRefType>(type);
1599 if (!memrefType || !memrefType.hasStaticShape())
1600 return parser.emitError(parser.getNameLoc())
1601 << "type should be static shaped memref, but got " << type;
1602 typeAttr = TypeAttr::get(type);
1603
1604 if (parser.parseOptionalEqual())
1605 return success();
1606
1607 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1608 initialValue = UnitAttr::get(parser.getContext());
1609 return success();
1610 }
1611
1612 Type tensorType = getTensorTypeFromMemRefType(memrefType);
1613 if (parser.parseAttribute(initialValue, tensorType))
1614 return failure();
1615 if (!llvm::isa<ElementsAttr>(initialValue))
1616 return parser.emitError(parser.getNameLoc())
1617 << "initial value should be a unit or elements attribute";
1618 return success();
1619}
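// Illustrative textual forms covered by the printer/parser above (a sketch;
// symbol names and values are placeholders):
//
//   memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @z : memref<3xf16> = uninitialized
//   memref.global "private" @external_decl : memref<4xi32>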
1620
1621LogicalResult GlobalOp::verify() {
1622 auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1623 if (!memrefType || !memrefType.hasStaticShape())
1624 return emitOpError("type should be static shaped memref, but got ")
1625 << getType();
1626
1627 // Verify that the initial value, if present, is either a unit attribute or
1628 // an elements attribute.
1629 if (getInitialValue().has_value()) {
1630 Attribute initValue = getInitialValue().value();
1631 if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1632 return emitOpError("initial value should be a unit or elements "
1633 "attribute, but got ")
1634 << initValue;
1635
1636 // Check that the type of the initial value is compatible with the type of
1637 // the global variable.
1638 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1639 // Check the element types match.
1640 auto initElementType =
1641 cast<TensorType>(elementsAttr.getType()).getElementType();
1642 auto memrefElementType = memrefType.getElementType();
1643
1644 if (initElementType != memrefElementType)
1645 return emitOpError("initial value element expected to be of type ")
1646 << memrefElementType << ", but was of type " << initElementType;
1647
1648 // Check the shapes match, given that memref globals can only produce
1649 // statically shaped memrefs and elements literal type must have a static
1650 // shape we can assume both types are shaped.
1651 auto initShape = elementsAttr.getShapedType().getShape();
1652 auto memrefShape = memrefType.getShape();
1653 if (initShape != memrefShape)
1654 return emitOpError("initial value shape expected to be ")
1655 << memrefShape << " but was " << initShape;
1656 }
1657 }
1658
1659 // TODO: verify visibility for declarations.
1660 return success();
1661}
1662
1663ElementsAttr GlobalOp::getConstantInitValue() {
1664 auto initVal = getInitialValue();
1665 if (getConstant() && initVal.has_value())
1666 return llvm::cast<ElementsAttr>(initVal.value());
1667 return {};
1668}
1669
1670//===----------------------------------------------------------------------===//
1671// GetGlobalOp
1672//===----------------------------------------------------------------------===//
1673
1674LogicalResult
1675GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1676 // Verify that the result type is same as the type of the referenced
1677 // memref.global op.
1678 auto global =
1679 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1680 if (!global)
1681 return emitOpError("'")
1682 << getName() << "' does not reference a valid global memref";
1683
1684 Type resultType = getResult().getType();
1685 if (global.getType() != resultType)
1686 return emitOpError("result type ")
1687 << resultType << " does not match type " << global.getType()
1688 << " of the global memref @" << getName();
1689 return success();
1690}
1691
1692//===----------------------------------------------------------------------===//
1693// LoadOp
1694//===----------------------------------------------------------------------===//
1695
1696LogicalResult LoadOp::verify() {
1697 if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1698 return emitOpError("incorrect number of indices for load, expected ")
1699 << getMemRefType().getRank() << " but got " << getIndices().size();
1700 }
1701 return success();
1702}
1703
1704OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1705 /// load(memrefcast) -> load
1706 if (succeeded(foldMemRefCast(*this)))
1707 return getResult();
1708 return OpFoldResult();
1709}
1710
1711FailureOr<std::optional<SmallVector<Value>>>
1712LoadOp::bubbleDownCasts(OpBuilder &builder) {
1714 getResult());
1715}
1716
1717//===----------------------------------------------------------------------===//
1718// MemorySpaceCastOp
1719//===----------------------------------------------------------------------===//
1720
1721void MemorySpaceCastOp::getAsmResultNames(
1722 function_ref<void(Value, StringRef)> setNameFn) {
1723 setNameFn(getResult(), "memspacecast");
1724}
1725
1726bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1727 if (inputs.size() != 1 || outputs.size() != 1)
1728 return false;
1729 Type a = inputs.front(), b = outputs.front();
1730 auto aT = llvm::dyn_cast<MemRefType>(a);
1731 auto bT = llvm::dyn_cast<MemRefType>(b);
1732
1733 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1734 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1735
1736 if (aT && bT) {
1737 if (aT.getElementType() != bT.getElementType())
1738 return false;
1739 if (aT.getLayout() != bT.getLayout())
1740 return false;
1741 if (aT.getShape() != bT.getShape())
1742 return false;
1743 return true;
1744 }
1745 if (uaT && ubT) {
1746 return uaT.getElementType() == ubT.getElementType();
1747 }
1748 return false;
1749}
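// For example (sketch): casting memref<4xf32, 1> to memref<4xf32> only
// changes the memory space and is compatible; changing the element type,
// shape, or layout is not.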
1750
1751OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1752 // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1753 // t2)
1754 if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1755 getSourceMutable().assign(parentCast.getSource());
1756 return getResult();
1757 }
1758 return Value{};
1759}
1760
1761TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
1762 return getSource();
1763}
1764
1765TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
1766 return getDest();
1767}
1768
1769bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1770 PtrLikeTypeInterface src) {
1771 return isa<BaseMemRefType>(tgt) &&
1772 tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
1773}
1774
1775MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
1776 OpBuilder &b, PtrLikeTypeInterface tgt,
1777 TypedValue<PtrLikeTypeInterface> src) {
1778 assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
1779 return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
1780}
1781
1782/// The only cast we recognize as promotable is to the generic space.
1783bool MemorySpaceCastOp::isSourcePromotable() {
1784 return getDest().getType().getMemorySpace() == nullptr;
1785}
1786
1787//===----------------------------------------------------------------------===//
1788// PrefetchOp
1789//===----------------------------------------------------------------------===//
1790
1791void PrefetchOp::print(OpAsmPrinter &p) {
1792 p << " " << getMemref() << '[';
1793 p.printOperands(getIndices());
1794 p << ']' << ", " << (getIsWrite() ? "write" : "read");
1795 p << ", locality<" << getLocalityHint();
1796 p << ">, " << (getIsDataCache() ? "data" : "instr");
1797 p.printOptionalAttrDict(
1798 (*this)->getAttrs(),
1799 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1800 p << " : " << getMemRefType();
1801}
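// Illustrative textual form handled by the print/parse methods here (a
// sketch with placeholder SSA names):
//
//   memref.prefetch %0[%i, %j], read, locality<3>, data : memref<8x64xf32>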
1802
1803ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1804 OpAsmParser::UnresolvedOperand memrefInfo;
1805 SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1806 IntegerAttr localityHint;
1807 MemRefType type;
1808 StringRef readOrWrite, cacheType;
1809
1810 auto indexTy = parser.getBuilder().getIndexType();
1811 auto i32Type = parser.getBuilder().getIntegerType(32);
1812 if (parser.parseOperand(memrefInfo) ||
1813 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1814 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1815 parser.parseComma() || parser.parseKeyword("locality") ||
1816 parser.parseLess() ||
1817 parser.parseAttribute(localityHint, i32Type, "localityHint",
1818 result.attributes) ||
1819 parser.parseGreater() || parser.parseComma() ||
1820 parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1821 parser.resolveOperand(memrefInfo, type, result.operands) ||
1822 parser.resolveOperands(indexInfo, indexTy, result.operands))
1823 return failure();
1824
1825 if (readOrWrite != "read" && readOrWrite != "write")
1826 return parser.emitError(parser.getNameLoc(),
1827 "rw specifier has to be 'read' or 'write'");
1828 result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1829 parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1830
1831 if (cacheType != "data" && cacheType != "instr")
1832 return parser.emitError(parser.getNameLoc(),
1833 "cache type has to be 'data' or 'instr'");
1834
1835 result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1836 parser.getBuilder().getBoolAttr(cacheType == "data"));
1837
1838 return success();
1839}
1840
1841LogicalResult PrefetchOp::verify() {
1842 if (getNumOperands() != 1 + getMemRefType().getRank())
1843 return emitOpError("too few indices");
1844
1845 return success();
1846}
1847
1848LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1849 SmallVectorImpl<OpFoldResult> &results) {
1850 // prefetch(memrefcast) -> prefetch
1851 return foldMemRefCast(*this);
1852}
1853
1854//===----------------------------------------------------------------------===//
1855// RankOp
1856//===----------------------------------------------------------------------===//
1857
1858OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1859 // Constant fold rank when the rank of the operand is known.
1860 auto type = getOperand().getType();
1861 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1862 if (shapedType && shapedType.hasRank())
1863 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1864 return IntegerAttr();
1865}
1866
1867//===----------------------------------------------------------------------===//
1868// ReinterpretCastOp
1869//===----------------------------------------------------------------------===//
1870
1871void ReinterpretCastOp::getAsmResultNames(
1872 function_ref<void(Value, StringRef)> setNameFn) {
1873 setNameFn(getResult(), "reinterpret_cast");
1874}
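// Illustrative textual form (a sketch with placeholder names):
//
//   %dst = memref.reinterpret_cast %src to offset: [0], sizes: [10, 2],
//            strides: [2, 1] : memref<?xf32> to memref<10x2xf32, strided<[2, 1]>>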
1875
1876/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1877/// `staticSizes` and `staticStrides` are automatically filled with
1878/// source-memref-rank sentinel values that encode dynamic entries.
1879void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1880 MemRefType resultType, Value source,
1881 OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1882 ArrayRef<OpFoldResult> strides,
1883 ArrayRef<NamedAttribute> attrs) {
1884 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1885 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1886 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1887 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1888 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1889 result.addAttributes(attrs);
1890 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1891 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1892 b.getDenseI64ArrayAttr(staticSizes),
1893 b.getDenseI64ArrayAttr(staticStrides));
1894}
1895
1896void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1897 Value source, OpFoldResult offset,
1898 ArrayRef<OpFoldResult> sizes,
1899 ArrayRef<OpFoldResult> strides,
1900 ArrayRef<NamedAttribute> attrs) {
1901 auto sourceType = cast<BaseMemRefType>(source.getType());
1902 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1903 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1904 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1905 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1906 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1907 auto stridedLayout = StridedLayoutAttr::get(
1908 b.getContext(), staticOffsets.front(), staticStrides);
1909 auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1910 stridedLayout, sourceType.getMemorySpace());
1911 build(b, result, resultType, source, offset, sizes, strides, attrs);
1912}
1913
1914void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1915 MemRefType resultType, Value source,
1916 int64_t offset, ArrayRef<int64_t> sizes,
1917 ArrayRef<int64_t> strides,
1918 ArrayRef<NamedAttribute> attrs) {
1919 SmallVector<OpFoldResult> sizeValues =
1920 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1921 return b.getI64IntegerAttr(v);
1922 }));
1923 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1924 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1925 return b.getI64IntegerAttr(v);
1926 }));
1927 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1928 strideValues, attrs);
1929}
1930
1931void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1932 MemRefType resultType, Value source, Value offset,
1933 ValueRange sizes, ValueRange strides,
1934 ArrayRef<NamedAttribute> attrs) {
1935 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1936 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1937 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1938 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1939 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1940}
1941
1942// TODO: ponder whether we want to allow missing trailing sizes/strides that are
1943// completed automatically, like we have for subview and extract_slice.
1944LogicalResult ReinterpretCastOp::verify() {
1945 // The source and result memrefs should be in the same memory space.
1946 auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1947 auto resultType = llvm::cast<MemRefType>(getType());
1948 if (srcType.getMemorySpace() != resultType.getMemorySpace())
1949 return emitError("different memory spaces specified for source type ")
1950 << srcType << " and result memref type " << resultType;
1951 if (srcType.getElementType() != resultType.getElementType())
1952 return emitError("different element types specified for source type ")
1953 << srcType << " and result memref type " << resultType;
1954
1955 // Match sizes in result memref type and in static_sizes attribute.
1956 for (auto [idx, resultSize, expectedSize] :
1957 llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1958 if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
1959 return emitError("expected result type with size = ")
1960 << (ShapedType::isDynamic(expectedSize)
1961 ? std::string("dynamic")
1962 : std::to_string(expectedSize))
1963 << " instead of " << resultSize << " in dim = " << idx;
1964 }
1965
1966 // Match offset and strides in static_offset and static_strides attributes. If
1967 // result memref type has no affine map specified, this will assume an
1968 // identity layout.
1969 int64_t resultOffset;
1970 SmallVector<int64_t, 4> resultStrides;
1971 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1972 return emitError("expected result type to have strided layout but found ")
1973 << resultType;
1974
1975 // Match offset in result memref type and in static_offsets attribute.
1976 int64_t expectedOffset = getStaticOffsets().front();
1977 if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
1978 return emitError("expected result type with offset = ")
1979 << (ShapedType::isDynamic(expectedOffset)
1980 ? std::string("dynamic")
1981 : std::to_string(expectedOffset))
1982 << " instead of " << resultOffset;
1983
1984 // Match strides in result memref type and in static_strides attribute.
1985 for (auto [idx, resultStride, expectedStride] :
1986 llvm::enumerate(resultStrides, getStaticStrides())) {
1987 if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
1988 return emitError("expected result type with stride = ")
1989 << (ShapedType::isDynamic(expectedStride)
1990 ? std::string("dynamic")
1991 : std::to_string(expectedStride))
1992 << " instead of " << resultStride << " in dim = " << idx;
1993 }
1994
1995 return success();
1996}
1997
1998OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1999 Value src = getSource();
2000 auto getPrevSrc = [&]() -> Value {
2001 // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
2002 if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
2003 return prev.getSource();
2004
2005 // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
2006 if (auto prev = src.getDefiningOp<CastOp>())
2007 return prev.getSource();
2008
2009 // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
2010 // are 0.
2011 if (auto prev = src.getDefiningOp<SubViewOp>())
2012 if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
2013 return prev.getSource();
2014
2015 return nullptr;
2016 };
2017
2018 if (auto prevSrc = getPrevSrc()) {
2019 getSourceMutable().assign(prevSrc);
2020 return getResult();
2021 }
2022
2023 // reinterpret_cast(x) w/o offset/shape/stride changes -> x
2024 if (ShapedType::isStaticShape(getType().getShape()) &&
2025 src.getType() == getType() && getStaticOffsets().front() == 0) {
2026 return src;
2027 }
2028
2029 return nullptr;
2030}
2031
2032SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
2033 SmallVector<OpFoldResult> values = getMixedSizes();
2034 constifyIndexValues(values, getType().getShape());
2035 return values;
2036}
2037
2038SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
2039 SmallVector<OpFoldResult> values = getMixedStrides();
2040 SmallVector<int64_t> staticValues;
2041 int64_t unused;
2042 LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
2043 (void)status;
2044 assert(succeeded(status) && "could not get strides from type");
2045 constifyIndexValues(values, staticValues);
2046 return values;
2047}
2048
2049OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
2050 SmallVector<OpFoldResult> values = getMixedOffsets();
2051 assert(values.size() == 1 &&
2052 "reinterpret_cast must have one and only one offset");
2053 SmallVector<int64_t> staticValues, unused;
2054 int64_t offset;
2055 LogicalResult status = getType().getStridesAndOffset(unused, offset);
2056 (void)status;
2057 assert(succeeded(status) && "could not get offset from type");
2058 staticValues.push_back(offset);
2059 constifyIndexValues(values, staticValues);
2060 return values[0];
2061}
2062
2063namespace {
2064/// Replace the sequence:
2065/// ```
2066/// base, offset, sizes, strides = extract_strided_metadata src
2067/// dst = reinterpret_cast base to offset, sizes, strides
2068/// ```
2069/// With
2070///
2071/// ```
2072/// dst = memref.cast src
2073/// ```
2074///
2075/// Note: The cast operation is only inserted when the type of dst and src
2076/// are not the same. E.g., when going from <4xf32> to <?xf32>.
2077///
2078/// This pattern also matches when the offset, sizes, and strides don't come
2079/// directly from the `extract_strided_metadata`'s results but it can be
2080/// statically proven that they would hold the same values.
2081///
2082/// For instance, the following sequence would be replaced:
2083/// ```
2084/// base, offset, sizes, strides =
2085/// extract_strided_metadata memref : memref<3x4xty>
2086/// dst = reinterpret_cast base to 0, [3, 4], strides
2087/// ```
2088/// Because we know (thanks to the type of the input memref) that variable
2089/// `offset` and `sizes` will respectively hold 0 and [3, 4].
2090///
2091/// Similarly, the following sequence would be replaced:
2092/// ```
2093/// c0 = arith.constant 0
2094/// c4 = arith.constant 4
2095/// base, offset, sizes, strides =
2096/// extract_strided_metadata memref : memref<3x4xty>
2097/// dst = reinterpret_cast base to c0, [3, c4], strides
2098/// ```
2099/// Because we know that `offset` and `c0` will hold 0
2100/// and `c4` will hold 4.
2101///
2102/// If the pattern above does not match, the input of the
2103/// extract_strided_metadata is always folded into the input of the
2104/// reinterpret_cast operator. This allows for dead code elimination to get rid
2105/// of the extract_strided_metadata in some cases.
2106struct ReinterpretCastOpExtractStridedMetadataFolder
2107 : public OpRewritePattern<ReinterpretCastOp> {
2108public:
2109 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2110
2111 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2112 PatternRewriter &rewriter) const override {
2113 auto extractStridedMetadata =
2114 op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2115 if (!extractStridedMetadata)
2116 return failure();
2117
2118 // Check if the reinterpret cast reconstructs a memref with the exact same
2119 // properties as the extract strided metadata.
2120 auto isReinterpretCastNoop = [&]() -> bool {
2121 // First, check that the strides are the same.
2122 if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
2123 op.getConstifiedMixedStrides()))
2124 return false;
2125
2126 // Second, check the sizes.
2127 if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
2128 op.getConstifiedMixedSizes()))
2129 return false;
2130
2131 // Finally, check the offset.
2132 assert(op.getMixedOffsets().size() == 1 &&
2133 "reinterpret_cast with more than one offset should have been "
2134 "rejected by the verifier");
2135 return extractStridedMetadata.getConstifiedMixedOffset() ==
2136 op.getConstifiedMixedOffset();
2137 };
2138
2139 if (!isReinterpretCastNoop()) {
2140 // If the extract_strided_metadata / reinterpret_cast pair can't be
2141 // completely folded, then we can still fold the input of the
2142 // extract_strided_metadata into the input of the reinterpret_cast.
2143 // For some cases (e.g., static dimensions) the
2144 // extract_strided_metadata is then eliminated by dead code elimination.
2145 //
2146 // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2147 //
2148 // We can always fold the input of a extract_strided_metadata operator
2149 // to the input of a reinterpret_cast operator, because they point to
2150 // the same memory. Note that the reinterpret_cast does not use the
2151 // layout of its input memref, only its base memory pointer which is
2152 // the same as the base pointer returned by the extract_strided_metadata
2153 // operator and the base pointer of the extract_strided_metadata memref
2154 // input.
2155 rewriter.modifyOpInPlace(op, [&]() {
2156 op.getSourceMutable().assign(extractStridedMetadata.getSource());
2157 });
2158 return success();
2159 }
2160
2161 // At this point, we know that the back and forth between extract strided
2162 // metadata and reinterpret cast is a noop. However, the final type of the
2163 // reinterpret cast may not be exactly the same as the original memref.
2164 // E.g., it could be changing a dimension from static to dynamic. Check that
2165 // here and add a cast if necessary.
2166 Type srcTy = extractStridedMetadata.getSource().getType();
2167 if (srcTy == op.getResult().getType())
2168 rewriter.replaceOp(op, extractStridedMetadata.getSource());
2169 else
2170 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2171 extractStridedMetadata.getSource());
2172
2173 return success();
2174 }
2175};
2176
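/// Canonicalize a `reinterpret_cast` whose offset, sizes, or strides can be
/// proven constant from its result type: rebuild the op with the constified
/// values and cast the result back to the original (more dynamic) type.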
2177struct ReinterpretCastOpConstantFolder
2178 : public OpRewritePattern<ReinterpretCastOp> {
2179public:
2180 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2181
2182 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2183 PatternRewriter &rewriter) const override {
2184 unsigned srcStaticCount = llvm::count_if(
2185 llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
2186 op.getMixedStrides()),
2187 [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
2188
2189 SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
2190 SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
2191 SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
2192
2193 // TODO: Using counting comparison instead of direct comparison because
2194 // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
2195 // IntegerAttrs, while constifyIndexValues (and therefore
2196 // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
2197 if (srcStaticCount ==
2198 llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
2199 [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
2200 return failure();
2201
2202 auto newReinterpretCast = ReinterpretCastOp::create(
2203 rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
2204
2205 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
2206 return success();
2207 }
2208};
2209} // namespace
2210
2211void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2212 MLIRContext *context) {
2213 results.add<ReinterpretCastOpExtractStridedMetadataFolder,
2214 ReinterpretCastOpConstantFolder>(context);
2215}
2216
2217FailureOr<std::optional<SmallVector<Value>>>
2218ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
2219 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2220}
2221
2222//===----------------------------------------------------------------------===//
2223// Reassociative reshape ops
2224//===----------------------------------------------------------------------===//
2225
2226void CollapseShapeOp::getAsmResultNames(
2227 function_ref<void(Value, StringRef)> setNameFn) {
2228 setNameFn(getResult(), "collapse_shape");
2229}
2230
2231void ExpandShapeOp::getAsmResultNames(
2232 function_ref<void(Value, StringRef)> setNameFn) {
2233 setNameFn(getResult(), "expand_shape");
2234}
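// Illustrative forms of the two reshape ops (a sketch with placeholder
// names):
//
//   %e = memref.expand_shape %m [[0, 1], [2]] output_shape [2, 4, 8]
//          : memref<8x8xf32> into memref<2x4x8xf32>
//   %c = memref.collapse_shape %e [[0, 1], [2]]
//          : memref<2x4x8xf32> into memref<8x8xf32>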
2235
2236LogicalResult ExpandShapeOp::reifyResultShapes(
2237 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2238 reifiedResultShapes = {
2239 getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2240 return success();
2241}
2242
2243/// Helper function for verifying the shape of ExpandShapeOp and
2244/// CollapseShapeOp result and operand. Layout maps are verified separately.
2245///
2246/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2247/// allowed in a reassociation group.
2248static LogicalResult
2249verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2250 ArrayRef<int64_t> expandedShape,
2251 ArrayRef<ReassociationIndices> reassociation,
2252 bool allowMultipleDynamicDimsPerGroup) {
2253 // There must be one reassociation group per collapsed dimension.
2254 if (collapsedShape.size() != reassociation.size())
2255 return op->emitOpError("invalid number of reassociation groups: found ")
2256 << reassociation.size() << ", expected " << collapsedShape.size();
2257
2258 // The next expected expanded dimension index (while iterating over
2259 // reassociation indices).
2260 int64_t nextDim = 0;
2261 for (const auto &it : llvm::enumerate(reassociation)) {
2262 ReassociationIndices group = it.value();
2263 int64_t collapsedDim = it.index();
2264
2265 bool foundDynamic = false;
2266 for (int64_t expandedDim : group) {
2267 if (expandedDim != nextDim++)
2268 return op->emitOpError("reassociation indices must be contiguous");
2269
2270 if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2271 return op->emitOpError("reassociation index ")
2272 << expandedDim << " is out of bounds";
2273
2274 // Check if there are multiple dynamic dims in a reassociation group.
2275 if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2276 if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2277 return op->emitOpError(
2278 "at most one dimension in a reassociation group may be dynamic");
2279 foundDynamic = true;
2280 }
2281 }
2282
2283 // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2284 if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2285 return op->emitOpError("collapsed dim (")
2286 << collapsedDim
2287 << ") must be dynamic if and only if reassociation group is "
2288 "dynamic";
2289
2290 // If all dims in the reassociation group are static, the size of the
2291 // collapsed dim can be verified.
2292 if (!foundDynamic) {
2293 int64_t groupSize = 1;
2294 for (int64_t expandedDim : group)
2295 groupSize *= expandedShape[expandedDim];
2296 if (groupSize != collapsedShape[collapsedDim])
2297 return op->emitOpError("collapsed dim size (")
2298 << collapsedShape[collapsedDim]
2299 << ") must equal reassociation group size (" << groupSize << ")";
2300 }
2301 }
2302
2303 if (collapsedShape.empty()) {
2304 // Rank 0: All expanded dimensions must be 1.
2305 for (int64_t d : expandedShape)
2306 if (d != 1)
2307 return op->emitOpError(
2308 "rank 0 memrefs can only be extended/collapsed with/from ones");
2309 } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2310 // Rank >= 1: Number of dimensions among all reassociation groups must match
2311 // the result memref rank.
2312 return op->emitOpError("expanded rank (")
2313 << expandedShape.size()
2314 << ") inconsistent with number of reassociation indices (" << nextDim
2315 << ")";
2316 }
2317
2318 return success();
2319}
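// For instance (sketch), collapsing memref<2x3x?xf32> into memref<6x?xf32>
// with reassociation [[0, 1], [2]] passes this check: the fully static group
// [0, 1] multiplies to 6, and the dynamic group [2] maps to a dynamic dim.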
2320
2321SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2322 return getSymbolLessAffineMaps(getReassociationExprs());
2323}
2324
2325SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2326 return convertReassociationIndicesToExprs(getContext(),
2327 getReassociationIndices());
2328}
2329
2330SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2331 return getSymbolLessAffineMaps(getReassociationExprs());
2332}
2333
2334SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2335 return convertReassociationIndicesToExprs(getContext(),
2336 getReassociationIndices());
2337}
2338
2339/// Compute the layout map after expanding a given source MemRef type with the
2340/// specified reassociation indices.
2341static FailureOr<StridedLayoutAttr>
2342computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2343 ArrayRef<ReassociationIndices> reassociation) {
2344 int64_t srcOffset;
2345 SmallVector<int64_t> srcStrides;
2346 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2347 return failure();
2348 assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2349
2350 // 1-1 mapping between srcStrides and reassociation packs.
2351 // Each srcStride starts with the given value and gets expanded according to
2352 // the proper entries in resultShape.
2353 // Example:
2354 // srcStrides = [10000, 1 , 100 ],
2355 // reassociations = [ [0], [1], [2, 3, 4]],
2356 // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2357 // -> For the purpose of stride calculation, the useful sizes are:
2358 // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2359 // resultStrides = [10000, 1, 600, 200, 100]
2360 // Note that a stride does not get expanded along the first entry of each
2361 // shape pack.
2362 SmallVector<int64_t> reverseResultStrides;
2363 reverseResultStrides.reserve(resultShape.size());
2364 unsigned shapeIndex = resultShape.size() - 1;
2365 for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2366 ReassociationIndices reassoc = std::get<0>(it);
2367 int64_t currentStrideToExpand = std::get<1>(it);
2368 for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2369 reverseResultStrides.push_back(currentStrideToExpand);
2370 currentStrideToExpand =
2371 (SaturatedInteger::wrap(currentStrideToExpand) *
2372 SaturatedInteger::wrap(resultShape[shapeIndex--]))
2373 .asInteger();
2374 }
2375 }
2376 auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2377 resultStrides.resize(resultShape.size(), 1);
2378 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2379}
2380
2381FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2382 MemRefType srcType, ArrayRef<int64_t> resultShape,
2383 ArrayRef<ReassociationIndices> reassociation) {
2384 if (srcType.getLayout().isIdentity()) {
2385 // If the source is contiguous (i.e., no layout map specified), so is the
2386 // result.
2387 MemRefLayoutAttrInterface layout;
2388 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2389 srcType.getMemorySpace());
2390 }
2391
2392 // Source may not be contiguous. Compute the layout map.
2393 FailureOr<StridedLayoutAttr> computedLayout =
2394 computeExpandedLayoutMap(srcType, resultShape, reassociation);
2395 if (failed(computedLayout))
2396 return failure();
2397 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2398 srcType.getMemorySpace());
2399}
2400
2401FailureOr<SmallVector<OpFoldResult>>
2402ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2403 MemRefType expandedType,
2404 ArrayRef<ReassociationIndices> reassociation,
2405 ArrayRef<OpFoldResult> inputShape) {
2406 std::optional<SmallVector<OpFoldResult>> outputShape =
2407 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2408 inputShape);
2409 if (!outputShape)
2410 return failure();
2411 return *outputShape;
2412}
2413
2414void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2415 Type resultType, Value src,
2416 ArrayRef<ReassociationIndices> reassociation,
2417 ArrayRef<OpFoldResult> outputShape) {
2418 auto [staticOutputShape, dynamicOutputShape] =
2419 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2420 build(builder, result, llvm::cast<MemRefType>(resultType), src,
2421 getReassociationIndicesAttribute(builder, reassociation),
2422 dynamicOutputShape, staticOutputShape);
2423}
2424
2425void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2426 Type resultType, Value src,
2427 ArrayRef<ReassociationIndices> reassociation) {
2428 SmallVector<OpFoldResult> inputShape =
2429 getMixedSizes(builder, result.location, src);
2430 MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2431 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2432 builder, result.location, memrefResultTy, reassociation, inputShape);
2433 // Failure of this assertion usually indicates presence of multiple
2434 // dynamic dimensions in the same reassociation group.
2435 assert(succeeded(outputShape) && "unable to infer output shape");
2436 build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2437}
2438
2439void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2440 ArrayRef<int64_t> resultShape, Value src,
2441 ArrayRef<ReassociationIndices> reassociation) {
2442 // Only ranked memref source values are supported.
2443 auto srcType = llvm::cast<MemRefType>(src.getType());
2444 FailureOr<MemRefType> resultType =
2445 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2446 // Failure of this assertion usually indicates a problem with the source
2447 // type, e.g., could not get strides/offset.
2448 assert(succeeded(resultType) && "could not compute layout");
2449 build(builder, result, *resultType, src, reassociation);
2450}
2451
2452void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2453 ArrayRef<int64_t> resultShape, Value src,
2454 ArrayRef<ReassociationIndices> reassociation,
2455 ArrayRef<OpFoldResult> outputShape) {
2456 // Only ranked memref source values are supported.
2457 auto srcType = llvm::cast<MemRefType>(src.getType());
2458 FailureOr<MemRefType> resultType =
2459 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2460 // Failure of this assertion usually indicates a problem with the source
2461 // type, e.g., could not get strides/offset.
2462 assert(succeeded(resultType) && "could not compute layout");
2463 build(builder, result, *resultType, src, reassociation, outputShape);
2464}
2465
2466LogicalResult ExpandShapeOp::verify() {
2467 MemRefType srcType = getSrcType();
2468 MemRefType resultType = getResultType();
2469
2470 if (srcType.getRank() > resultType.getRank()) {
2471 auto r0 = srcType.getRank();
2472 auto r1 = resultType.getRank();
2473 return emitOpError("has source rank ")
2474 << r0 << " and result rank " << r1 << ". This is not an expansion ("
2475 << r0 << " > " << r1 << ").";
2476 }
2477
2478 // Verify result shape.
2479 if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2480 resultType.getShape(),
2481 getReassociationIndices(),
2482 /*allowMultipleDynamicDimsPerGroup=*/true)))
2483 return failure();
2484
2485 // Compute expected result type (including layout map).
2486 FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2487 srcType, resultType.getShape(), getReassociationIndices());
2488 if (failed(expectedResultType))
2489 return emitOpError("invalid source layout map");
2490
2491 // Check actual result type.
2492 if (*expectedResultType != resultType)
2493 return emitOpError("expected expanded type to be ")
2494 << *expectedResultType << " but found " << resultType;
2495
2496 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2497 return emitOpError("expected number of static shape bounds to be equal to "
2498 "the output rank (")
2499 << resultType.getRank() << ") but found "
2500 << getStaticOutputShape().size() << " inputs instead";
2501
2502 if ((int64_t)getOutputShape().size() !=
2503 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2504 return emitOpError("mismatch in dynamic dims in output_shape and "
2505 "static_output_shape: static_output_shape has ")
2506 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2507 << " dynamic dims while output_shape has " << getOutputShape().size()
2508 << " values";
2509
2510 // Verify if provided output shapes are in agreement with output type.
2511 DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2512 ArrayRef<int64_t> resShape = getResult().getType().getShape();
2513 for (auto [pos, shape] : llvm::enumerate(resShape)) {
2514 if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
2515 return emitOpError("invalid output shape provided at pos ") << pos;
2516 }
2517 }
2518
2519 return success();
2520}
2521
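/// Fold a `memref.cast` that feeds an `expand_shape` into the `expand_shape`
/// itself, provided the dynamicity of each reassociation group is preserved;
/// a cast back to the original result type is inserted when needed.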
2522struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
2523public:
2524 using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;
2525
2526 LogicalResult matchAndRewrite(ExpandShapeOp op,
2527 PatternRewriter &rewriter) const override {
2528 auto cast = op.getSrc().getDefiningOp<CastOp>();
2529 if (!cast)
2530 return failure();
2531
2532 if (!CastOp::canFoldIntoConsumerOp(cast))
2533 return failure();
2534
2535 SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
2536 SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
2537 SmallVector<int64_t> newOutputShapeSizes;
2538
2539 // Convert output shape dims from dynamic to static where possible.
2540 for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
2541 std::optional<int64_t> sizeOpt = getConstantIntValue(dimSize);
2542 if (!sizeOpt.has_value()) {
2543 newOutputShapeSizes.push_back(ShapedType::kDynamic);
2544 continue;
2545 }
2546
2547 newOutputShapeSizes.push_back(sizeOpt.value());
2548 newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
2549 }
2550
2551 Value castSource = cast.getSource();
2552 auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
2553 SmallVector<ReassociationIndices> reassociationIndices =
2554 op.getReassociationIndices();
2555 for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
2556 auto newOutputShapeSizesSlice =
2557 ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
2558 bool newOutputDynamic =
2559 llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
2560 if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
2561 return rewriter.notifyMatchFailure(
2562 op, "folding cast will result in changing dynamicity in "
2563 "reassociation group");
2564 }
2565
2566 FailureOr<MemRefType> newResultTypeOrFailure =
2567 ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
2568 reassociationIndices);
2569
2570 if (failed(newResultTypeOrFailure))
2571 return rewriter.notifyMatchFailure(
2572 op, "could not compute new expanded type after folding cast");
2573
2574 if (*newResultTypeOrFailure == op.getResultType()) {
2575 rewriter.modifyOpInPlace(
2576 op, [&]() { op.getSrcMutable().assign(castSource); });
2577 } else {
2578 Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
2579 *newResultTypeOrFailure, castSource,
2580 reassociationIndices, newOutputShape);
2581 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2582 }
2583 return success();
2584 }
2585};
2586
2587void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2588 MLIRContext *context) {
2589 results.add<
2590 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2591 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2592 ExpandShapeOpMemRefCastFolder>(context);
2593}
2594
2595FailureOr<std::optional<SmallVector<Value>>>
2596ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
2597 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2598}
2599
2600/// Compute the layout map after collapsing a given source MemRef type with the
2601/// specified reassociation indices.
2602///
2603/// Note: All collapsed dims in a reassociation group must be contiguous. It is
2604/// not possible to check this by inspecting a MemRefType in the general case.
2605/// If non-contiguity cannot be checked statically, the collapse is assumed to
2606/// be valid (and thus accepted by this function) unless `strict = true`.
2607static FailureOr<StridedLayoutAttr>
2608computeCollapsedLayoutMap(MemRefType srcType,
2609 ArrayRef<ReassociationIndices> reassociation,
2610 bool strict = false) {
2611 int64_t srcOffset;
2612 SmallVector<int64_t> srcStrides;
2613 auto srcShape = srcType.getShape();
2614 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2615 return failure();
2616
2617 // The result stride of a reassociation group is the stride of the last entry
2618 // of the reassociation. (TODO: Should be the minimum stride in the
2619 // reassociation because strides are not necessarily sorted. E.g., when using
2620 // memref.transpose.) Dimensions of size 1 should be skipped, because their
2621 // strides are meaningless and could have any arbitrary value.
2622 SmallVector<int64_t> resultStrides;
2623 resultStrides.reserve(reassociation.size());
2624 for (const ReassociationIndices &reassoc : reassociation) {
2625 ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2626 while (srcShape[ref.back()] == 1 && ref.size() > 1)
2627 ref = ref.drop_back();
2628 if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2629 resultStrides.push_back(srcStrides[ref.back()]);
2630 } else {
2631 // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2632 // the corresponding stride may have to be skipped. (See above comment.)
2633 // Therefore, the result stride cannot be statically determined and must
2634 // be dynamic.
2635 resultStrides.push_back(ShapedType::kDynamic);
2636 }
2637 }
2638
2639 // Validate that each reassociation group is contiguous.
2640 unsigned resultStrideIndex = resultStrides.size() - 1;
2641 for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2642 auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2643 auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2644 for (int64_t idx : llvm::reverse(trailingReassocs)) {
2645 stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2646
2647 // Dimensions of size 1 should be skipped, because their strides are
2648 // meaningless and could have any arbitrary value.
2649 if (srcShape[idx - 1] == 1)
2650 continue;
2651
2652 // Both source and result stride must have the same static value. In that
2653 // case, we can be sure that the dimensions are collapsible (because they
2654 // are contiguous).
2655 // If `strict = false` (default during op verification), we accept cases
2656 // where one or both strides are dynamic. This is best effort: We reject
2657 // ops where obviously non-contiguous dims are collapsed, but accept ops
2658 // where we cannot be sure statically. Such ops may fail at runtime. See
2659 // the op documentation for details.
2660 auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2661 if (strict && (stride.saturated || srcStride.saturated))
2662 return failure();
2663
2664 if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2665 return failure();
2666 }
2667 }
2668 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2669}
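// Worked example (sketch): collapsing memref<2x3x4xf32> (strides [12, 4, 1])
// with reassociation [[0, 1], [2]] keeps the stride of the last entry of each
// group, giving result strides [4, 1]; the contiguity check then verifies
// that 4 * 3 == 12 for the leading group.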
2670
2671bool CollapseShapeOp::isGuaranteedCollapsible(
2672 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2673 // MemRefs with identity layout are always collapsible.
2674 if (srcType.getLayout().isIdentity())
2675 return true;
2676
2677 return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2678 /*strict=*/true));
2679}
2680
2681MemRefType CollapseShapeOp::computeCollapsedType(
2682 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2683 SmallVector<int64_t> resultShape;
2684 resultShape.reserve(reassociation.size());
2685 for (const ReassociationIndices &group : reassociation) {
2686 auto groupSize = SaturatedInteger::wrap(1);
2687 for (int64_t srcDim : group)
2688 groupSize =
2689 groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2690 resultShape.push_back(groupSize.asInteger());
2691 }
2692
2693 if (srcType.getLayout().isIdentity()) {
2694 // If the source is contiguous (i.e., no layout map specified), so is the
2695 // result.
2696 MemRefLayoutAttrInterface layout;
2697 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2698 srcType.getMemorySpace());
2699 }
2700
2701 // Source may not be fully contiguous. Compute the layout map.
2702 // Note: Dimensions that are collapsed into a single dim are assumed to be
2703 // contiguous.
2704 FailureOr<StridedLayoutAttr> computedLayout =
2705 computeCollapsedLayoutMap(srcType, reassociation);
2706 assert(succeeded(computedLayout) &&
2707 "invalid source layout map or collapsing non-contiguous dims");
2708 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2709 srcType.getMemorySpace());
2710}
2711
2712void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2713 ArrayRef<ReassociationIndices> reassociation,
2714 ArrayRef<NamedAttribute> attrs) {
2715 auto srcType = llvm::cast<MemRefType>(src.getType());
2716 MemRefType resultType =
2717 CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2718 result.addAttribute(getReassociationAttrStrName(),
2719 getReassociationIndicesAttribute(b, reassociation));
2720 build(b, result, resultType, src, attrs);
2721}
2722
2723LogicalResult CollapseShapeOp::verify() {
2724 MemRefType srcType = getSrcType();
2725 MemRefType resultType = getResultType();
2726
2727 if (srcType.getRank() < resultType.getRank()) {
2728 auto r0 = srcType.getRank();
2729 auto r1 = resultType.getRank();
2730 return emitOpError("has source rank ")
2731 << r0 << " and result rank " << r1 << ". This is not a collapse ("
2732 << r0 << " < " << r1 << ").";
2733 }
2734
2735 // Verify result shape.
2736 if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2737 srcType.getShape(), getReassociationIndices(),
2738 /*allowMultipleDynamicDimsPerGroup=*/true)))
2739 return failure();
2740
2741 // Compute expected result type (including layout map).
2742 MemRefType expectedResultType;
2743 if (srcType.getLayout().isIdentity()) {
2744 // If the source is contiguous (i.e., no layout map specified), so is the
2745 // result.
2746 MemRefLayoutAttrInterface layout;
2747 expectedResultType =
2748 MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2749 srcType.getMemorySpace());
2750 } else {
2751 // Source may not be fully contiguous. Compute the layout map.
2752 // Note: Dimensions that are collapsed into a single dim are assumed to be
2753 // contiguous.
2754 FailureOr<StridedLayoutAttr> computedLayout =
2755 computeCollapsedLayoutMap(srcType, getReassociationIndices());
2756 if (failed(computedLayout))
2757 return emitOpError(
2758 "invalid source layout map or collapsing non-contiguous dims");
2759 expectedResultType =
2760 MemRefType::get(resultType.getShape(), srcType.getElementType(),
2761 *computedLayout, srcType.getMemorySpace());
2762 }
2763
2764 if (expectedResultType != resultType)
2765 return emitOpError("expected collapsed type to be ")
2766 << expectedResultType << " but found " << resultType;
2767
2768 return success();
2769}
2770
2771class CollapseShapeOpMemRefCastFolder final
2772 : public OpRewritePattern<CollapseShapeOp> {
2773public:
2774 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2775
2776 LogicalResult matchAndRewrite(CollapseShapeOp op,
2777 PatternRewriter &rewriter) const override {
2778 auto cast = op.getOperand().getDefiningOp<CastOp>();
2779 if (!cast)
2780 return failure();
2781
2782 if (!CastOp::canFoldIntoConsumerOp(cast))
2783 return failure();
2784
2785 Type newResultType = CollapseShapeOp::computeCollapsedType(
2786 llvm::cast<MemRefType>(cast.getOperand().getType()),
2787 op.getReassociationIndices());
2788
2789 if (newResultType == op.getResultType()) {
2790 rewriter.modifyOpInPlace(
2791 op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2792 } else {
2793 Value newOp =
2794 CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
2795 op.getReassociationIndices());
2796 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2797 }
2798 return success();
2799 }
2800};
2801
2802void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2803 MLIRContext *context) {
2804 results.add<
2805 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2806 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2807 memref::DimOp, MemRefType>,
2808 CollapseShapeOpMemRefCastFolder>(context);
2809}
2810
2811OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2812 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2813 adaptor.getOperands());
2814}
2815
2816OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2817 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2818 adaptor.getOperands());
2819}
2820
2821FailureOr<std::optional<SmallVector<Value>>>
2822CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
2823 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2824}
2825
2826//===----------------------------------------------------------------------===//
2827// ReshapeOp
2828//===----------------------------------------------------------------------===//
2829
2830void ReshapeOp::getAsmResultNames(
2831 function_ref<void(Value, StringRef)> setNameFn) {
2832 setNameFn(getResult(), "reshape");
2833}
2834
2835LogicalResult ReshapeOp::verify() {
2836 Type operandType = getSource().getType();
2837 Type resultType = getResult().getType();
2838
2839 Type operandElementType =
2840 llvm::cast<ShapedType>(operandType).getElementType();
2841 Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2842 if (operandElementType != resultElementType)
2843 return emitOpError("element types of source and destination memref "
2844 "types should be the same");
2845
2846 if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2847 if (!operandMemRefType.getLayout().isIdentity())
2848 return emitOpError("source memref type should have identity affine map");
2849
2850 int64_t shapeSize =
2851 llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2852 auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2853 if (resultMemRefType) {
2854 if (!resultMemRefType.getLayout().isIdentity())
2855 return emitOpError("result memref type should have identity affine map");
2856 if (shapeSize == ShapedType::kDynamic)
2857 return emitOpError("cannot use shape operand with dynamic length to "
2858 "reshape to statically-ranked memref type");
2859 if (shapeSize != resultMemRefType.getRank())
2860 return emitOpError(
2861 "length of shape operand differs from the result's memref rank");
2862 }
2863 return success();
2864}
2865
2866FailureOr<std::optional<SmallVector<Value>>>
2867ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
2868 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2869}
2870
2871//===----------------------------------------------------------------------===//
2872// StoreOp
2873//===----------------------------------------------------------------------===//
2874
2875LogicalResult StoreOp::verify() {
2876 if (getNumOperands() != 2 + getMemRefType().getRank())
2877 return emitOpError("store index operand count not equal to memref rank");
2878
2879 return success();
2880}
2881
2882LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2883 SmallVectorImpl<OpFoldResult> &results) {
2884 /// store(memrefcast) -> store
2885 return foldMemRefCast(*this, getValueToStore());
2886}
2887
2888FailureOr<std::optional<SmallVector<Value>>>
2889StoreOp::bubbleDownCasts(OpBuilder &builder) {
2891 ValueRange());
2892}
2893
2894//===----------------------------------------------------------------------===//
2895// SubViewOp
2896//===----------------------------------------------------------------------===//
2897
2898void SubViewOp::getAsmResultNames(
2899 function_ref<void(Value, StringRef)> setNameFn) {
2900 setNameFn(getResult(), "subview");
2901}
2902
2903/// A subview result type can be fully inferred from the source type and the
2904/// static representation of offsets, sizes and strides. Special sentinels
2905/// encode the dynamic case.
2906MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2907 ArrayRef<int64_t> staticOffsets,
2908 ArrayRef<int64_t> staticSizes,
2909 ArrayRef<int64_t> staticStrides) {
2910 unsigned rank = sourceMemRefType.getRank();
2911 (void)rank;
2912 assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2913 assert(staticSizes.size() == rank && "staticSizes length mismatch");
2914 assert(staticStrides.size() == rank && "staticStrides length mismatch");
2915
2916 // Extract source offset and strides.
2917 auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2918
2919 // Compute target offset whose value is:
2920 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2921 int64_t targetOffset = sourceOffset;
2922 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2923 auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2924 targetOffset = (SaturatedInteger::wrap(targetOffset) +
2925 SaturatedInteger::wrap(staticOffset) *
2926 SaturatedInteger::wrap(sourceStride))
2927 .asInteger();
2928 }
2929
2930 // Compute target stride whose value is:
2931 // `sourceStrides_i * staticStrides_i`.
2932 SmallVector<int64_t, 4> targetStrides;
2933 targetStrides.reserve(staticOffsets.size());
2934 for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2935 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2936 targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2937 SaturatedInteger::wrap(staticStride))
2938 .asInteger());
2939 }
2940
2941 // The type is now known.
2942 return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2943 StridedLayoutAttr::get(sourceMemRefType.getContext(),
2944 targetOffset, targetStrides),
2945 sourceMemRefType.getMemorySpace());
2946}
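// Worked example (sketch): for a source memref<8x16xf32> (strides [16, 1],
// offset 0) with offsets [2, 4], sizes [4, 4], and strides [1, 2], the result
// is memref<4x4xf32, strided<[16, 2], offset: 36>>, since
// targetOffset = 0 + 2 * 16 + 4 * 1 = 36 and targetStrides = [16 * 1, 1 * 2].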
2947
2948MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2949 ArrayRef<OpFoldResult> offsets,
2950 ArrayRef<OpFoldResult> sizes,
2951 ArrayRef<OpFoldResult> strides) {
2952 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2953 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2954 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2955 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2956 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2957 if (!hasValidSizesOffsets(staticOffsets))
2958 return {};
2959 if (!hasValidSizesOffsets(staticSizes))
2960 return {};
2961 if (!hasValidStrides(staticStrides))
2962 return {};
2963 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2964 staticSizes, staticStrides);
2965}
2966
2967MemRefType SubViewOp::inferRankReducedResultType(
2968 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2969 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2970 ArrayRef<int64_t> strides) {
2971 MemRefType inferredType =
2972 inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2973 assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2974 "expected result shape rank to not exceed the inferred rank");
2975 if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2976 return inferredType;
2977
2978 // Compute which dimensions are dropped.
2979 std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2980 computeRankReductionMask(inferredType.getShape(), resultShape);
2981 assert(dimsToProject.has_value() && "invalid rank reduction");
2982
2983 // Compute the layout and result type.
2984 auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2985 SmallVector<int64_t> rankReducedStrides;
2986 rankReducedStrides.reserve(resultShape.size());
2987 for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2988 if (!dimsToProject->contains(idx))
2989 rankReducedStrides.push_back(value);
2990 }
2991 return MemRefType::get(resultShape, inferredType.getElementType(),
2992 StridedLayoutAttr::get(inferredLayout.getContext(),
2993 inferredLayout.getOffset(),
2994 rankReducedStrides),
2995 inferredType.getMemorySpace());
2996}
2997
2998MemRefType SubViewOp::inferRankReducedResultType(
2999 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
3000 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
3001 ArrayRef<OpFoldResult> strides) {
3002 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3003 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3004 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3005 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3006 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3007 return SubViewOp::inferRankReducedResultType(
3008 resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
3009 staticStrides);
3010}
3011
3012// Build a SubViewOp with mixed static and dynamic entries and custom result
3013// type. If the type passed is nullptr, it is inferred.
3014void SubViewOp::build(OpBuilder &b, OperationState &result,
3015 MemRefType resultType, Value source,
3016 ArrayRef<OpFoldResult> offsets,
3017 ArrayRef<OpFoldResult> sizes,
3018 ArrayRef<OpFoldResult> strides,
3019 ArrayRef<NamedAttribute> attrs) {
3020 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3021 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3022 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3023 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3024 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3025 auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
3026 // Structuring implementation this way avoids duplication between builders.
3027 if (!resultType) {
3028 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
3029 staticSizes, staticStrides);
3030 }
3031 result.addAttributes(attrs);
3032 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
3033 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3034 b.getDenseI64ArrayAttr(staticSizes),
3035 b.getDenseI64ArrayAttr(staticStrides));
3036}
3037
3038// Build a SubViewOp with mixed static and dynamic entries and inferred result
3039// type.
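//
// A minimal usage sketch (illustrative only; `b`, `loc` and the 2-D memref
// value `src` are assumed to exist in the calling code):
//
//   SmallVector<OpFoldResult> offs(2, b.getIndexAttr(0));  // offsets [0, 0]
//   SmallVector<OpFoldResult> szs(2, b.getIndexAttr(2));   // sizes   [2, 2]
//   SmallVector<OpFoldResult> ones(2, b.getIndexAttr(1));  // strides [1, 1]
//   auto sv = memref::SubViewOp::create(b, loc, src, offs, szs, ones);
//
// The result type is then inferred by the builder below via `inferResultType`.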
3040void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3041 ArrayRef<OpFoldResult> offsets,
3042 ArrayRef<OpFoldResult> sizes,
3043 ArrayRef<OpFoldResult> strides,
3044 ArrayRef<NamedAttribute> attrs) {
3045 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
3046}
3047
3048// Build a SubViewOp with static entries and inferred result type.
3049void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3050 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3051 ArrayRef<int64_t> strides,
3052 ArrayRef<NamedAttribute> attrs) {
3053 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3054 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
3055 return b.getI64IntegerAttr(v);
3056 }));
3057 SmallVector<OpFoldResult> sizeValues =
3058 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
3059 return b.getI64IntegerAttr(v);
3060 }));
3061 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3062 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
3063 return b.getI64IntegerAttr(v);
3064 }));
3065 build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
3066}
3067
3068// Build a SubViewOp with static entries and custom result type. If the
3069// type passed is nullptr, it is inferred.
3070void SubViewOp::build(OpBuilder &b, OperationState &result,
3071 MemRefType resultType, Value source,
3072 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3073 ArrayRef<int64_t> strides,
3074 ArrayRef<NamedAttribute> attrs) {
3075 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3076 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
3077 return b.getI64IntegerAttr(v);
3078 }));
3079 SmallVector<OpFoldResult> sizeValues =
3080 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
3081 return b.getI64IntegerAttr(v);
3082 }));
3083 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3084 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
3085 return b.getI64IntegerAttr(v);
3086 }));
3087 build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
3088 attrs);
3089}
3090
3091// Build a SubViewOp with dynamic entries and custom result type. If the type
3092// passed is nullptr, it is inferred.
3093void SubViewOp::build(OpBuilder &b, OperationState &result,
3094 MemRefType resultType, Value source, ValueRange offsets,
3095 ValueRange sizes, ValueRange strides,
3096 ArrayRef<NamedAttribute> attrs) {
3097 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3098 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3099 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3100 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3101 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3102 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3103 build(b, result, resultType, source, offsetValues, sizeValues, strideValues, attrs);
3104}
3105
3106// Build a SubViewOp with dynamic entries and inferred result type.
3107void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3108 ValueRange offsets, ValueRange sizes, ValueRange strides,
3109 ArrayRef<NamedAttribute> attrs) {
3110 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
3111}
3112
3113/// For ViewLikeOpInterface.
3114Value SubViewOp::getViewSource() { return getSource(); }
3115
3116/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
3117/// static value).
3118static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
3119 int64_t t1Offset, t2Offset;
3120 SmallVector<int64_t> t1Strides, t2Strides;
3121 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3122 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3123 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
3124}
3125
3126/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
3127/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
3128/// marked as dropped in `droppedDims`.
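///
/// For example, `t1 = memref<4x1x8xf32, strided<[8, 8, 1]>>` has strides
/// compatible with the rank-reduced `t2 = memref<4x8xf32, strided<[8, 1]>>`
/// when `droppedDims = {1}`: the stride of the dropped unit dimension is
/// skipped and the remaining strides `[8, 1]` match pairwise.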
3129static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
3130 const llvm::SmallBitVector &droppedDims) {
3131 assert(size_t(t1.getRank()) == droppedDims.size() &&
3132 "incorrect number of bits");
3133 assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
3134 "incorrect number of dropped dims");
3135 int64_t t1Offset, t2Offset;
3136 SmallVector<int64_t> t1Strides, t2Strides;
3137 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3138 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3139 if (failed(res1) || failed(res2))
3140 return false;
3141 for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
3142 if (droppedDims[i])
3143 continue;
3144 if (t1Strides[i] != t2Strides[j])
3145 return false;
3146 ++j;
3147 }
3148 return true;
3149}
3150
3151static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
3152 SubViewOp op, Type expectedType) {
3153 auto memrefType = llvm::cast<ShapedType>(expectedType);
3154 switch (result) {
3155 case SliceVerificationResult::Success:
3156 return success();
3157 case SliceVerificationResult::RankTooLarge:
3158 return op->emitError("expected result rank to be smaller or equal to ")
3159 << "the source rank, but got " << op.getType();
3160 case SliceVerificationResult::SizeMismatch:
3161 return op->emitError("expected result type to be ")
3162 << expectedType
3163 << " or a rank-reduced version. (mismatch of result sizes), but got "
3164 << op.getType();
3165 case SliceVerificationResult::ElemTypeMismatch:
3166 return op->emitError("expected result element type to be ")
3167 << memrefType.getElementType() << ", but got " << op.getType();
3168 case SliceVerificationResult::MemSpaceMismatch:
3169 return op->emitError(
3170 "expected result and source memory spaces to match, but got ")
3171 << op.getType();
3172 case SliceVerificationResult::LayoutMismatch:
3173 return op->emitError("expected result type to be ")
3174 << expectedType
3175 << " or a rank-reduced version. (mismatch of result layout), but "
3176 "got "
3177 << op.getType();
3178 }
3179 llvm_unreachable("unexpected subview verification result");
3180}
3181
3182/// Verifier for SubViewOp.
3183LogicalResult SubViewOp::verify() {
3184 MemRefType baseType = getSourceType();
3185 MemRefType subViewType = getType();
3186 ArrayRef<int64_t> staticOffsets = getStaticOffsets();
3187 ArrayRef<int64_t> staticSizes = getStaticSizes();
3188 ArrayRef<int64_t> staticStrides = getStaticStrides();
3189
3190 // The base memref and the view memref should be in the same memory space.
3191 if (baseType.getMemorySpace() != subViewType.getMemorySpace())
3192 return emitError("different memory spaces specified for base memref "
3193 "type ")
3194 << baseType << " and subview memref type " << subViewType;
3195
3196 // Verify that the base memref type has a strided layout map.
3197 if (!baseType.isStrided())
3198 return emitError("base type ") << baseType << " is not strided";
3199
3200 // Compute the expected result type, assuming that there are no rank
3201 // reductions.
3202 MemRefType expectedType = SubViewOp::inferResultType(
3203 baseType, staticOffsets, staticSizes, staticStrides);
3204
3205 // Verify all properties of a shaped type: rank, element type and dimension
3206 // sizes. This takes into account potential rank reductions.
3207 auto shapedTypeVerification = isRankReducedType(
3208 /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
3209 if (shapedTypeVerification != SliceVerificationResult::Success)
3210 return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
3211
3212 // Make sure that the memory space did not change.
3213 if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3214 return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
3215 *this, expectedType);
3216
3217 // Verify the offset of the layout map.
3218 if (!haveCompatibleOffsets(expectedType, subViewType))
3219 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3220 *this, expectedType);
3221
3222 // All that is left to verify now are the strides. First, compute
3223 // the unused dimensions due to rank reductions. We have to look at sizes and
3224 // strides to decide which dimensions were dropped. This function also
3225 // partially verifies strides in case of rank reductions.
3226 auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3227 getMixedSizes());
3228 if (failed(unusedDims))
3229 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3230 *this, expectedType);
3231
3232 // Strides must match.
3233 if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3234 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3235 *this, expectedType);
3236
3237 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3238 // to the base memref.
3239 SliceBoundsVerificationResult boundsResult =
3240 verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
3241 staticStrides, /*generateErrorMessage=*/true);
3242 if (!boundsResult.isValid)
3243 return getOperation()->emitError(boundsResult.errorMessage);
3244
3245 return success();
3246}
3247
3248raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3249 return os << "range " << range.offset << ":" << range.size << ":"
3250 << range.stride;
3251}
3252
3253/// Return the list of Range (i.e. offset, size, stride). Each Range
3254/// entry contains either the dynamic value or a ConstantIndexOp constructed
3255/// with `b` at location `loc`.
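///
/// For example, for an op with mixed offsets `[0, %i]`, sizes `[4, 8]` and
/// strides `[1, 1]`, this returns the entries `{%c0, %c4, %c1}` and
/// `{%i, %c8, %c1}`, where each `%cN` is a freshly materialized
/// `arith.constant N : index`.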
3256SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3257 OpBuilder &b, Location loc) {
3258 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3259 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3260 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3261 SmallVector<Range, 8> res;
3262 unsigned rank = ranks[0];
3263 res.reserve(rank);
3264 for (unsigned idx = 0; idx < rank; ++idx) {
3265 Value offset =
3266 op.isDynamicOffset(idx)
3267 ? op.getDynamicOffset(idx)
3268 : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
3269 Value size =
3270 op.isDynamicSize(idx)
3271 ? op.getDynamicSize(idx)
3272 : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
3273 Value stride =
3274 op.isDynamicStride(idx)
3275 ? op.getDynamicStride(idx)
3276 : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
3277 res.emplace_back(Range{offset, size, stride});
3278 }
3279 return res;
3280}
3281
3282/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3283/// to deduce the result type for the given `sourceType`. Additionally, reduce
3284/// the rank of the inferred result type if `currentResultType` is lower rank
3285/// than `currentSourceType`. Use this signature if `sourceType` is updated
3286/// together with the result type. In this case, it is important to compute
3287/// the dropped dimensions using `currentSourceType` whose strides align with
3288/// `currentResultType`.
3289static MemRefType getCanonicalSubViewResultType(
3290 MemRefType currentResultType, MemRefType currentSourceType,
3291 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3292 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3293 MemRefType nonRankReducedType = SubViewOp::inferResultType(
3294 sourceType, mixedOffsets, mixedSizes, mixedStrides);
3295 FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3296 currentSourceType, currentResultType, mixedSizes);
3297 if (failed(unusedDims))
3298 return nullptr;
3299
3300 auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3301 SmallVector<int64_t> shape, strides;
3302 unsigned numDimsAfterReduction =
3303 nonRankReducedType.getRank() - unusedDims->count();
3304 shape.reserve(numDimsAfterReduction);
3305 strides.reserve(numDimsAfterReduction);
3306 for (const auto &[idx, size, stride] :
3307 llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3308 nonRankReducedType.getShape(), layout.getStrides())) {
3309 if (unusedDims->test(idx))
3310 continue;
3311 shape.push_back(size);
3312 strides.push_back(stride);
3313 }
3314
3315 return MemRefType::get(shape, nonRankReducedType.getElementType(),
3316 StridedLayoutAttr::get(sourceType.getContext(),
3317 layout.getOffset(), strides),
3318 nonRankReducedType.getMemorySpace());
3319}
3320
3321Value mlir::memref::createCanonicalRankReducingSubViewOp(
3322 OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3323 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3324 unsigned rank = memrefType.getRank();
3325 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3326 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3327 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3328 MemRefType targetType = SubViewOp::inferRankReducedResultType(
3329 targetShape, memrefType, offsets, sizes, strides);
3330 return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3331 sizes, strides);
3332}
3333
3334FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3335 Value value,
3336 ArrayRef<int64_t> desiredShape) {
3337 auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3338 assert(sourceMemrefType && "not a ranked memref type");
3339 auto sourceShape = sourceMemrefType.getShape();
3340 if (sourceShape.equals(desiredShape))
3341 return value;
3342 auto maybeRankReductionMask =
3343 mlir::computeRankReductionMask(sourceShape, desiredShape);
3344 if (!maybeRankReductionMask)
3345 return failure();
3346 return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3347}
3348
3349/// Helper method to check if a `subview` operation is trivially a no-op. This
3350/// is the case if all offsets are zero, all strides are one, and the source
3351/// shape matches the sizes of the subview. In such cases, the subview can
3352/// be folded into its source.
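///
/// For example,
///   %0 = memref.subview %src[0, 0] [4, 8] [1, 1]
///          : memref<4x8xf32> to memref<4x8xf32>
/// is trivial and can be folded away to `%src` directly.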
3353static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3354 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3355 return false;
3356
3357 auto mixedOffsets = subViewOp.getMixedOffsets();
3358 auto mixedSizes = subViewOp.getMixedSizes();
3359 auto mixedStrides = subViewOp.getMixedStrides();
3360
3361 // Check offsets are zero.
3362 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3363 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3364 return !intValue || intValue.value() != 0;
3365 }))
3366 return false;
3367
3368 // Check strides are one.
3369 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3370 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3371 return !intValue || intValue.value() != 1;
3372 }))
3373 return false;
3374
3375 // Check that all size values are static and match the (static) source shape.
3376 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3377 for (const auto &size : llvm::enumerate(mixedSizes)) {
3378 std::optional<int64_t> intValue = getConstantIntValue(size.value());
3379 if (!intValue || *intValue != sourceShape[size.index()])
3380 return false;
3381 }
3382 // All conditions met. The `SubViewOp` is foldable as a no-op.
3383 return true;
3384}
3385
3386namespace {
3387/// Pattern to rewrite a subview op with MemRefCast arguments.
3388/// This essentially pushes memref.cast past its consuming subview when
3389/// `canFoldIntoConsumerOp` is true.
3390///
3391/// Example:
3392/// ```
3393/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3394/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3395/// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3396/// ```
3397/// is rewritten into:
3398/// ```
3399/// %0 = memref.subview %V[0, 0][3, 4][1, 1] : memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3400/// %1 = memref.cast %0 : memref<3x4xf32, strided<[16, 1], offset: 0>> to
3401/// memref<3x4xf32, strided<[?, 1], offset: ?>>
3402/// ```
3403class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3404public:
3405 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3406
3407 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3408 PatternRewriter &rewriter) const override {
3409 // Any constant operand, just return to let SubViewOpConstantFolder kick
3410 // in.
3411 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3412 return matchPattern(operand, matchConstantIndex());
3413 }))
3414 return failure();
3415
3416 auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3417 if (!castOp)
3418 return failure();
3419
3420 if (!CastOp::canFoldIntoConsumerOp(castOp))
3421 return failure();
3422
3423 // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3424 // the MemRefCastOp source operand type to infer the result type and the
3425 // current SubViewOp source operand type to compute the dropped dimensions
3426 // if the operation is rank-reducing.
3427 auto resultType = getCanonicalSubViewResultType(
3428 subViewOp.getType(), subViewOp.getSourceType(),
3429 llvm::cast<MemRefType>(castOp.getSource().getType()),
3430 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3431 subViewOp.getMixedStrides());
3432 if (!resultType)
3433 return failure();
3434
3435 Value newSubView = SubViewOp::create(
3436 rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
3437 subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3438 subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3439 subViewOp.getStaticStrides());
3440 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3441 newSubView);
3442 return success();
3443 }
3444};
3445
3446/// Canonicalize subview ops that are no-ops. When the source type differs
3447/// from the result type (e.g. due to an `affine_map` layout), a cast is used.
3448class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3449public:
3450 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3451
3452 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3453 PatternRewriter &rewriter) const override {
3454 if (!isTrivialSubViewOp(subViewOp))
3455 return failure();
3456 if (subViewOp.getSourceType() == subViewOp.getType()) {
3457 rewriter.replaceOp(subViewOp, subViewOp.getSource());
3458 return success();
3459 }
3460 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3461 subViewOp.getSource());
3462 return success();
3463 }
3464};
3465} // namespace
3466
3467/// Return the canonical type of the result of a subview.
3468struct SubViewReturnTypeCanonicalizer {
3469 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3470 ArrayRef<OpFoldResult> mixedSizes,
3471 ArrayRef<OpFoldResult> mixedStrides) {
3472 // Infer a memref type without taking into account any rank reductions.
3473 MemRefType resTy = SubViewOp::inferResultType(
3474 op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3475 if (!resTy)
3476 return {};
3477 MemRefType nonReducedType = resTy;
3478
3479 // Directly return the non-rank reduced type if there are no dropped dims.
3480 llvm::SmallBitVector droppedDims = op.getDroppedDims();
3481 if (droppedDims.none())
3482 return nonReducedType;
3483
3484 // Take the strides and offset from the non-rank reduced type.
3485 auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3486
3487 // Drop dims from shape and strides.
3488 SmallVector<int64_t> targetShape;
3489 SmallVector<int64_t> targetStrides;
3490 for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3491 if (droppedDims.test(i))
3492 continue;
3493 targetStrides.push_back(nonReducedStrides[i]);
3494 targetShape.push_back(nonReducedType.getDimSize(i));
3495 }
3496
3497 return MemRefType::get(targetShape, nonReducedType.getElementType(),
3498 StridedLayoutAttr::get(nonReducedType.getContext(),
3499 offset, targetStrides),
3500 nonReducedType.getMemorySpace());
3501 }
3502};
3503
3504/// A canonicalizer wrapper to replace SubViewOps.
3505struct SubViewCanonicalizer {
3506 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3507 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3508 }
3509};
3510
3511void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3512 MLIRContext *context) {
3513 results
3514 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3515 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3516 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3517}
3518
3519OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3520 MemRefType sourceMemrefType = getSource().getType();
3521 MemRefType resultMemrefType = getResult().getType();
3522 auto resultLayout =
3523 dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3524
3525 if (resultMemrefType == sourceMemrefType &&
3526 resultMemrefType.hasStaticShape() &&
3527 (!resultLayout || resultLayout.hasStaticLayout())) {
3528 return getViewSource();
3529 }
3530
3531 // Fold subview(subview(x)), where both subviews have the same sizes, the
3532 // second subview's offsets are all zero, and its strides are all one (i.e.,
3533 // the second subview is a no-op).
3534 if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3535 auto srcSizes = srcSubview.getMixedSizes();
3536 auto sizes = getMixedSizes();
3537 auto offsets = getMixedOffsets();
3538 bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3539 auto strides = getMixedStrides();
3540 bool allStridesOne = llvm::all_of(strides, isOneInteger);
3541 bool allSizesSame = llvm::equal(sizes, srcSizes);
3542 if (allOffsetsZero && allStridesOne && allSizesSame &&
3543 resultMemrefType == sourceMemrefType)
3544 return getViewSource();
3545 }
3546
3547 return {};
3548}
3549
3550FailureOr<std::optional<SmallVector<Value>>>
3551SubViewOp::bubbleDownCasts(OpBuilder &builder) {
3552 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3553}
3554
3555void SubViewOp::inferStridedMetadataRanges(
3556 ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
3557 SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
3558 auto isUninitialized =
3559 +[](IntegerValueRange range) { return range.isUninitialized(); };
3560
3561 // Bail early if any of the operands' metadata is not ready:
3562 SmallVector<IntegerValueRange> offsetOperands =
3563 getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
3564 if (llvm::any_of(offsetOperands, isUninitialized))
3565 return;
3566
3567 SmallVector<IntegerValueRange> sizeOperands =
3568 getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
3569 if (llvm::any_of(sizeOperands, isUninitialized))
3570 return;
3571
3572 SmallVector<IntegerValueRange> stridesOperands =
3573 getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
3574 if (llvm::any_of(stridesOperands, isUninitialized))
3575 return;
3576
3577 StridedMetadataRange sourceRange =
3578 ranges[getSourceMutable().getOperandNumber()];
3579 if (sourceRange.isUninitialized())
3580 return;
3581
3582 ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
3583
3584 // Get the dropped dims.
3585 llvm::SmallBitVector droppedDims = getDroppedDims();
3586
3587 // Compute the new offset, strides and sizes.
3588 ConstantIntRanges offset = sourceRange.getOffsets()[0];
3589 SmallVector<ConstantIntRanges> strides, sizes;
3590
3591 for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
3592 bool dropped = droppedDims.test(i);
3593 // Compute the new offset.
3594 ConstantIntRanges off =
3595 intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
3596 offset = intrange::inferAdd({offset, off});
3597
3598 // Skip dropped dimensions.
3599 if (dropped)
3600 continue;
3601 // Multiply the strides.
3602 strides.push_back(
3603 intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
3604 // Get the sizes.
3605 sizes.push_back(sizeOperands[i].getValue());
3606 }
3607
3608 setMetadata(getResult(),
3609 StridedMetadataRange::getRanked(
3610 SmallVector<ConstantIntRanges>({std::move(offset)}),
3611 std::move(sizes), std::move(strides)));
3612}
3613
3614//===----------------------------------------------------------------------===//
3615// TransposeOp
3616//===----------------------------------------------------------------------===//
3617
3618void TransposeOp::getAsmResultNames(
3619 function_ref<void(Value, StringRef)> setNameFn) {
3620 setNameFn(getResult(), "transpose");
3621}
3622
3623/// Build a strided memref type by applying `permutationMap` to `memRefType`.
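///
/// For example, applying the permutation `(d0, d1) -> (d1, d0)` to
/// `memref<2x3xf32>` (strides `[3, 1]`, offset 0) permutes both the shape and
/// the strides, yielding `memref<3x2xf32, strided<[1, 3]>>`.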
3624static MemRefType inferTransposeResultType(MemRefType memRefType,
3625 AffineMap permutationMap) {
3626 auto originalSizes = memRefType.getShape();
3627 auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3628 assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3629
3630 // Compute permuted sizes and strides.
3631 auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3632 auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3633
3634 return MemRefType::Builder(memRefType)
3635 .setShape(sizes)
3636 .setLayout(
3637 StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3638}
3639
3640void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3641 AffineMapAttr permutation,
3642 ArrayRef<NamedAttribute> attrs) {
3643 auto permutationMap = permutation.getValue();
3644 assert(permutationMap);
3645
3646 auto memRefType = llvm::cast<MemRefType>(in.getType());
3647 // Compute result type.
3648 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3649
3650 result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3651 build(b, result, resultType, in, attrs);
3652}
3653
3654// transpose $in $permutation attr-dict : type($in) `to` type(results)
3655void TransposeOp::print(OpAsmPrinter &p) {
3656 p << " " << getIn() << " " << getPermutation();
3657 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3658 p << " : " << getIn().getType() << " to " << getType();
3659}
3660
3661ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3662 OpAsmParser::UnresolvedOperand in;
3663 AffineMap permutation;
3664 MemRefType srcType, dstType;
3665 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3666 parser.parseOptionalAttrDict(result.attributes) ||
3667 parser.parseColonType(srcType) ||
3668 parser.resolveOperand(in, srcType, result.operands) ||
3669 parser.parseKeywordType("to", dstType) ||
3670 parser.addTypeToList(dstType, result.types))
3671 return failure();
3672
3673 result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3674 AffineMapAttr::get(permutation));
3675 return success();
3676}
3677
3678LogicalResult TransposeOp::verify() {
3679 if (!getPermutation().isPermutation())
3680 return emitOpError("expected a permutation map");
3681 if (getPermutation().getNumDims() != getIn().getType().getRank())
3682 return emitOpError("expected a permutation map of same rank as the input");
3683
3684 auto srcType = llvm::cast<MemRefType>(getIn().getType());
3685 auto resultType = llvm::cast<MemRefType>(getType());
3686 auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3687 .canonicalizeStridedLayout();
3688
3689 if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3690 return emitOpError("result type ")
3691 << resultType
3692 << " is not equivalent to the canonical transposed input type "
3693 << canonicalResultType;
3694 return success();
3695}
3696
3697OpFoldResult TransposeOp::fold(FoldAdaptor) {
3698 // First check for an identity permutation; it can be folded away if the
3699 // input and result types are already identical.
3700 if (getPermutation().isIdentity() && getType() == getIn().getType())
3701 return getIn();
3702 // Fold two consecutive memref.transpose Ops into one by composing their
3703 // permutation maps.
3704 if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3705 AffineMap composedPermutation =
3706 getPermutation().compose(otherTransposeOp.getPermutation());
3707 getInMutable().assign(otherTransposeOp.getIn());
3708 setPermutation(composedPermutation);
3709 return getResult();
3710 }
3711 return {};
3712}
3713
3714FailureOr<std::optional<SmallVector<Value>>>
3715TransposeOp::bubbleDownCasts(OpBuilder &builder) {
3716 return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
3717}
3718
3719//===----------------------------------------------------------------------===//
3720// ViewOp
3721//===----------------------------------------------------------------------===//
3722
3723void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3724 setNameFn(getResult(), "view");
3725}
3726
3727LogicalResult ViewOp::verify() {
3728 auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3729 auto viewType = getType();
3730
3731 // The base memref should have identity layout map (or none).
3732 if (!baseType.getLayout().isIdentity())
3733 return emitError("unsupported map for base memref type ") << baseType;
3734
3735 // The result memref should have identity layout map (or none).
3736 if (!viewType.getLayout().isIdentity())
3737 return emitError("unsupported map for result memref type ") << viewType;
3738
3739 // The base memref and the view memref should be in the same memory space.
3740 if (baseType.getMemorySpace() != viewType.getMemorySpace())
3741 return emitError("different memory spaces specified for base memref "
3742 "type ")
3743 << baseType << " and view memref type " << viewType;
3744
3745 // Verify that we have the correct number of sizes for the result type.
3746 if (failed(verifyDynamicDimensionCount(getOperation(), viewType, getSizes())))
3747 return failure();
3748
3749 return success();
3750}
3751
3752Value ViewOp::getViewSource() { return getSource(); }
3753
3754OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3755 MemRefType sourceMemrefType = getSource().getType();
3756 MemRefType resultMemrefType = getResult().getType();
3757
3758 if (resultMemrefType == sourceMemrefType &&
3759 resultMemrefType.hasStaticShape() && isZeroInteger(getByteShift()))
3760 return getViewSource();
3761
3762 return {};
3763}
3764
3765namespace {
3766
3767struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3768 using Base::Base;
3769
3770 LogicalResult matchAndRewrite(ViewOp viewOp,
3771 PatternRewriter &rewriter) const override {
3772 // Return if none of the operands are constants.
3773 if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3774 return matchPattern(operand, matchConstantIndex());
3775 }))
3776 return failure();
3777
3778 // Get result memref type.
3779 auto memrefType = viewOp.getType();
3780
3781 // Get offset from old memref view type 'memRefType'.
3782 int64_t oldOffset;
3783 SmallVector<int64_t, 4> oldStrides;
3784 if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3785 return failure();
3786 assert(oldOffset == 0 && "Expected 0 offset");
3787
3788 SmallVector<Value, 4> newOperands;
3789
3790 // Offset cannot be folded into result type.
3791
3792 // Fold any dynamic dim operands which are produced by a constant.
3793 SmallVector<int64_t, 4> newShapeConstants;
3794 newShapeConstants.reserve(memrefType.getRank());
3795
3796 unsigned dynamicDimPos = 0;
3797 unsigned rank = memrefType.getRank();
3798 for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3799 int64_t dimSize = memrefType.getDimSize(dim);
3800 // If this is already a static dimension, keep it.
3801 if (ShapedType::isStatic(dimSize)) {
3802 newShapeConstants.push_back(dimSize);
3803 continue;
3804 }
3805 auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3806 if (auto constantIndexOp =
3807 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3808 // Dynamic shape dimension will be folded.
3809 newShapeConstants.push_back(constantIndexOp.value());
3810 } else {
3811 // Dynamic shape dimension not folded; copy operand from old memref.
3812 newShapeConstants.push_back(dimSize);
3813 newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3814 }
3815 dynamicDimPos++;
3816 }
3817
3818 // Create new memref type with constant folded dims.
3819 MemRefType newMemRefType =
3820 MemRefType::Builder(memrefType).setShape(newShapeConstants);
3821 // Nothing new, don't fold.
3822 if (newMemRefType == memrefType)
3823 return failure();
3824
3825 // Create new ViewOp.
3826 auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
3827 viewOp.getOperand(0), viewOp.getByteShift(),
3828 newOperands);
3829 // Insert a cast so we have the same type as the old memref type.
3830 rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3831 return success();
3832 }
3833};
3834
3835/// view(memref.cast(%source)) -> view(%source).
3836struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3837 using Base::Base;
3838
3839 LogicalResult matchAndRewrite(ViewOp viewOp,
3840 PatternRewriter &rewriter) const override {
3841 auto memrefCastOp = viewOp.getSource().getDefiningOp<CastOp>();
3842 if (!memrefCastOp)
3843 return failure();
3844
3845 rewriter.replaceOpWithNewOp<ViewOp>(
3846 viewOp, viewOp.getType(), memrefCastOp.getSource(),
3847 viewOp.getByteShift(), viewOp.getSizes());
3848 return success();
3849 }
3850};
3851} // namespace
3852
3853void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3854 MLIRContext *context) {
3855 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3856}
3857
3858FailureOr<std::optional<SmallVector<Value>>>
3859ViewOp::bubbleDownCasts(OpBuilder &builder) {
3860 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3861}
3862
3863//===----------------------------------------------------------------------===//
3864// AtomicRMWOp
3865//===----------------------------------------------------------------------===//
3866
3867LogicalResult AtomicRMWOp::verify() {
3868 if (getMemRefType().getRank() != getNumOperands() - 2)
3869 return emitOpError(
3870 "expects the number of subscripts to be equal to memref rank");
3871 switch (getKind()) {
3872 case arith::AtomicRMWKind::addf:
3873 case arith::AtomicRMWKind::maximumf:
3874 case arith::AtomicRMWKind::minimumf:
3875 case arith::AtomicRMWKind::mulf:
3876 if (!llvm::isa<FloatType>(getValue().getType()))
3877 return emitOpError() << "with kind '"
3878 << arith::stringifyAtomicRMWKind(getKind())
3879 << "' expects a floating-point type";
3880 break;
3881 case arith::AtomicRMWKind::addi:
3882 case arith::AtomicRMWKind::maxs:
3883 case arith::AtomicRMWKind::maxu:
3884 case arith::AtomicRMWKind::mins:
3885 case arith::AtomicRMWKind::minu:
3886 case arith::AtomicRMWKind::muli:
3887 case arith::AtomicRMWKind::ori:
3888 case arith::AtomicRMWKind::xori:
3889 case arith::AtomicRMWKind::andi:
3890 if (!llvm::isa<IntegerType>(getValue().getType()))
3891 return emitOpError() << "with kind '"
3892 << arith::stringifyAtomicRMWKind(getKind())
3893 << "' expects an integer type";
3894 break;
3895 default:
3896 break;
3897 }
3898 return success();
3899}
3900
3901OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3902 /// atomicrmw(memrefcast) -> atomicrmw
3903 if (succeeded(foldMemRefCast(*this, getValue())))
3904 return getResult();
3905 return OpFoldResult();
3906}
3907
3908FailureOr<std::optional<SmallVector<Value>>>
3909AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
3910 return bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
3911 getResult());
3912}
3913
3914//===----------------------------------------------------------------------===//
3915// TableGen'd op method definitions
3916//===----------------------------------------------------------------------===//
3917
3918#define GET_OP_CLASSES
3919#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool hasSideEffects(Operation *op)
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition IRAffine.cpp:69
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
auto load
static LogicalResult foldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
Definition MemRefOps.cpp:97
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:255
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
IndexType getIndexType()
Definition Builders.cpp:51
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Region.h:98
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StridedMetadataRange getRanked(SmallVectorImpl< ConstantIntRanges > &&offsets, SmallVectorImpl< ConstantIntRanges > &&sizes, SmallVectorImpl< ConstantIntRanges > &&strides)
Returns a ranked strided metadata range.
ArrayRef< ConstantIntRanges > getStrides() const
Get the strides ranges.
bool isUninitialized() const
Returns whether the metadata is uninitialized.
ArrayRef< ConstantIntRanges > getOffsets() const
Get the offsets range.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
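As a hedged sketch of how the inference helpers above compose (assuming they are the mlir::intrange entry points and that offsetRange, indexRange, and strideRange are ConstantIntRanges already computed for the corresponding values), the range of offset + index * stride can be derived as:

// Range of `index * stride`, then of `offset + index * stride`.
ConstantIntRanges mul = intrange::inferMul({indexRange, strideRange});
ConstantIntRanges sum = intrange::inferAdd({offsetRange, mul});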
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition MemRefOps.cpp:61
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition MemRefOps.cpp:69
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:46
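A hedged sketch of how foldMemRefCast is typically used from an op's fold hook; MyLoadLikeOp is a hypothetical single-result op, not one defined in this file:

// Fold `myop(memref.cast(x), ...)` into `myop(x, ...)`; returning the op's
// own result signals an in-place fold.
OpFoldResult MyLoadLikeOp::fold(FoldAdaptor adaptor) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}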
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:78
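A small sketch combining the size helpers above; makeFullView is a hypothetical helper name, and the SubViewOp::create overload taking mixed offsets/sizes/strides is assumed:

// Build a full view of `src`: zero offsets, the queried sizes, unit strides.
static Value makeFullView(OpBuilder &b, Location loc, Value src) {
  SmallVector<OpFoldResult> sizes = memref::getMixedSizes(b, loc, src);
  SmallVector<OpFoldResult> offsets(sizes.size(), b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(sizes.size(), b.getIndexAttr(1));
  return memref::SubViewOp::create(b, loc, src, offsets, sizes, strides);
}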
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .. 0] with strides [1 .. 1] and appropriate sizes (i.e. memref.getSizes()) to reduce the rank of memref to that of targetShape.
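For illustration (assuming b and loc are an OpBuilder and Location and src has type memref<1x16x1x8xf32>), the helper drops the unit dimensions to reach the target shape:

// Produce a rank-reduced view of shape 16x8 from a 1x16x1x8 memref.
Value reduced = memref::createCanonicalRankReducingSubViewOp(
    b, loc, src, /*targetShape=*/{16, 8});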
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:67
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
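A small sketch of the matcher pair above; idx is assumed to be an index-typed Value in scope:

// Bind the constant value of `idx`, if it is defined by a constant op.
llvm::APInt cst;
if (matchPattern(idx, m_ConstantInt(&cst)) && cst.isZero()) {
  // `idx` is statically known to be 0.
}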
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
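A hedged sketch of the verification entry point above, checking a static 4x8 slice at offset [2, 0] with unit strides against an 8x8 shape (the isValid/errorMessage fields of the result struct appear further down in this listing):

SliceBoundsVerificationResult res = verifyInBoundsSlice(
    /*shape=*/{8, 8}, /*staticOffsets=*/{2, 0}, /*staticSizes=*/{4, 8},
    /*staticStrides=*/{1, 1}, /*generateErrorMessage=*/true);
if (!res.isValid)
  llvm::errs() << res.errorMessage << "\n";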
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size and stride); each Range entry contains either the dynamic value or a ConstantIndexOp constructed with b at location loc.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
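A sketch of the dispatch helper above; mixedSizes is assumed to be an ArrayRef<OpFoldResult> of sizes that may be static or dynamic:

// Split OpFoldResults into the static (int64_t) and dynamic (Value) vectors
// that ops storing a DenseI64ArrayAttr plus variadic size operands expect.
SmallVector<int64_t> staticSizes;
SmallVector<Value> dynamicSizes;
dispatchIndexOpFoldResults(mixedSizes, dynamicSizes, staticSizes);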
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:23
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn
Callback function type for setting the strided metadata of a value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
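A concrete, illustrative use of the mask computation above: reducing 1x4x1x8 to 4x8 drops the unit dimensions 0 and 2.

// Expected to return the set {0, 2}; std::nullopt would mean reducedShape is
// not a valid rank reduction of originalShape.
std::optional<llvm::SmallDenseSet<unsigned>> mask =
    computeRankReductionMask(/*originalShape=*/{1, 4, 1, 8},
                             /*reducedShape=*/{4, 8});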
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
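A short sketch of the reassociation encoding above (b is assumed to be a Builder): the grouping [[0, 1], [2]] collapses the first two dimensions into one.

// Encode the reassociation as the ArrayAttr stored on expand/collapse ops.
SmallVector<ReassociationIndices> reassoc = {{0, 1}, {2}};
ArrayAttr attr = getReassociationIndicesAttribute(b, reassoc);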
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
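A sketch of the inverse pair formed by decomposeMixedValues and getMixedValues above; mixedSizes and ctx are assumed to be an ArrayRef<OpFoldResult> and the MLIRContext:

// Split into (static int64_t values, dynamic Values) and recompose; the
// round trip should reproduce the original mixed representation.
auto [staticSizes, dynamicSizes] = decomposeMixedValues(mixedSizes);
SmallVector<OpFoldResult> roundTripped =
    getMixedValues(staticSizes, dynamicSizes, ctx);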
function_ref< IntegerValueRange(Value)> GetIntRangeFn
Helper callback type to get the integer range of a value.
Move allocations into an allocation scope, if it is legal to move them (e.g. their operands are available at the point to which they would be moved).
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
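A hypothetical rewrite-pattern skeleton in the style of the patterns listed above; FoldCastIntoDim is an illustrative name, not a pattern defined in this file:

// Rewrite memref.dim(memref.cast(x), i) to memref.dim(x, i) when the cast
// source is ranked, mirroring the memref.cast folding logic above.
struct FoldCastIntoDim : public OpRewritePattern<memref::DimOp> {
  using OpRewritePattern<memref::DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(memref::DimOp op,
                                PatternRewriter &rewriter) const override {
    auto cast = op.getSource().getDefiningOp<memref::CastOp>();
    if (!cast || isa<UnrankedMemRefType>(cast.getSource().getType()))
      return rewriter.notifyMatchFailure(op, "no foldable memref.cast");
    // Re-create the dim op on the cast's source, bypassing the cast.
    rewriter.replaceOpWithNewOp<memref::DimOp>(op, cast.getSource(),
                                               op.getIndex());
    return success();
  }
};

// Registration, e.g. inside a getCanonicalizationPatterns hook:
// results.add<FoldCastIntoDim>(context);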
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.