MLIR 23.0.0git
MemRefOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/AffineMap.h"
15#include "mlir/IR/Builders.h"
17#include "mlir/IR/Matchers.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
27#include "llvm/ADT/SmallVectorExtras.h"
28
29using namespace mlir;
30using namespace mlir::memref;
31
32/// Materialize a single constant operation from a given attribute value with
33/// the desired resultant type.
34Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
35 Attribute value, Type type,
36 Location loc) {
37 return arith::ConstantOp::materialize(builder, value, type, loc);
38}
39
40//===----------------------------------------------------------------------===//
41// Common canonicalization pattern support logic
42//===----------------------------------------------------------------------===//
43
44/// This is a common class used for patterns of the form
45/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
46/// into the root operation directly.
47LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
48 bool folded = false;
49 for (OpOperand &operand : op->getOpOperands()) {
50 auto cast = operand.get().getDefiningOp<CastOp>();
51 if (cast && operand.get() != inner &&
52 !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
53 operand.set(cast.getOperand());
54 folded = true;
55 }
56 }
57 return success(folded);
58}
59
/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
// NOTE(review): the function's declaration line is elided in this excerpt;
// the body below is kept verbatim.
  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    // Ranked memref -> ranked tensor with identical shape and element type.
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    // Unranked memref -> unranked tensor with the same element type.
    return UnrankedTensorType::get(memref.getElementType());
  // Not a memref type at all; signal that with NoneType.
  return NoneType::get(type.getContext());
}

// NOTE(review): the leading declaration line is elided in this excerpt. From
// the body: returns one dimension of `value` as an OpFoldResult — an
// attribute for static sizes, an SSA value for dynamic ones.
                            int64_t dim) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  // Dynamic dimensions need a runtime memref.dim query.
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  // Static dimensions fold directly to an index attribute.
  return builder.getIndexAttr(memrefType.getDimSize(dim));
}

// NOTE(review): the declaration line and the local declaration of `result`
// are elided in this excerpt; the body collects one OpFoldResult per
// dimension of the ranked memref `value`.
                                Location loc, Value value) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  for (int64_t i = 0; i < memrefType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}
87
88//===----------------------------------------------------------------------===//
89// Utility functions for propagating static information
90//===----------------------------------------------------------------------===//
91
/// Helper function that sets values[i] to constValues[i] if the latter is a
/// static value, as indicated by ShapedType::kDynamic.
///
/// If constValues[i] is dynamic, tries to extract a constant value from
/// value[i] to allow for additional folding opportunities. Also convertes all
/// existing attributes to index attributes. (They may be i64 attributes.)
// NOTE(review): the first signature line (declaring `values`) is elided in
// this excerpt; the body is kept verbatim.
                                ArrayRef<int64_t> constValues) {
  assert(constValues.size() == values.size() &&
         "incorrect number of const values");
  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
    // A fresh Builder is enough here: it is only used to create attributes.
    Builder builder(values[i].getContext());
    if (ShapedType::isStatic(cstVal)) {
      // Constant value is known, use it directly.
      values[i] = builder.getIndexAttr(cstVal);
      continue;
    }
    if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
      // Try to extract a constant or convert an existing to index.
      values[i] = builder.getIndexAttr(*cst);
    }
  }
}
115
/// Helper function to retrieve a lossless memory-space cast, and the
/// corresponding new result memref type.
static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
// NOTE(review): the line with the function name and parameters (`resultTy`,
// `src`) is elided in this excerpt; the body is kept verbatim. An empty
// tuple (null castOp) is the "cannot cast" result throughout.
  MemorySpaceCastOpInterface castOp =
      MemorySpaceCastOpInterface::getIfPromotableCast(src);

  // Bail if the cast is not lossless.
  if (!castOp)
    return {};

  // Transform the source and target type of `castOp` to have the same metadata
  // as `resultTy`. Bail if not possible.
  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
      castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
  if (failed(srcTy))
    return {};

  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
      castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
  if (failed(tgtTy))
    return {};

  // Check if this is a valid memory-space cast.
  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
    return {};

  return std::make_tuple(castOp, *tgtTy, *srcTy);
}
145
/// Implementation of `bubbleDownCasts` method for memref operations that
/// return a single memref result.
template <typename ConcreteOpTy>
static FailureOr<std::optional<SmallVector<Value>>>
// NOTE(review): the line with the function name and leading parameters (the
// op and the builder) is elided in this excerpt; the body is kept verbatim.
                                 OpOperand &src) {
  // Look for a lossless memory-space cast feeding `src`.
  auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
  // Bail if we cannot cast.
  if (!castOp)
    return failure();

  // Create the new operands, replacing `src` with the cast's source pointer.
  SmallVector<Value> operands;
  llvm::append_range(operands, op->getOperands());
  operands[src.getOperandNumber()] = castOp.getSourcePtr();

  // Create the new op and results, preserving properties and discardable
  // attributes of the original op.
  auto newOp = ConcreteOpTy::create(
      builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
      llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));

  // Insert a memory-space cast to the original memory space of the op.
  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
      builder, tgtTy,
      cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
  return std::optional<SmallVector<Value>>(
      SmallVector<Value>({result.getTargetPtr()}));
}
174
175//===----------------------------------------------------------------------===//
176// AllocOp / AllocaOp
177//===----------------------------------------------------------------------===//
178
179void AllocOp::getAsmResultNames(
180 function_ref<void(Value, StringRef)> setNameFn) {
181 setNameFn(getResult(), "alloc");
182}
183
184void AllocaOp::getAsmResultNames(
185 function_ref<void(Value, StringRef)> setNameFn) {
186 setNameFn(getResult(), "alloca");
187}
188
189template <typename AllocLikeOp>
190static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
191 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
192 "applies to only alloc or alloca");
193 auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
194 if (!memRefType)
195 return op.emitOpError("result must be a memref");
196
197 if (failed(verifyDynamicDimensionCount(op, memRefType, op.getDynamicSizes())))
198 return failure();
199
200 unsigned numSymbols = 0;
201 if (!memRefType.getLayout().isIdentity())
202 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
203 if (op.getSymbolOperands().size() != numSymbols)
204 return op.emitOpError("symbol operand count does not equal memref symbol "
205 "count: expected ")
206 << numSymbols << ", got " << op.getSymbolOperands().size();
207
208 return success();
209}
210
211LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
212
213LogicalResult AllocaOp::verify() {
214 // An alloca op needs to have an ancestor with an allocation scope trait.
215 if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
216 return emitOpError(
217 "requires an ancestor op with AutomaticAllocationScope trait");
218
219 return verifyAllocLikeOp(*this);
220}
221
222namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimensions operands are constants.  If so, we can
    // substitute and drop them. Only non-negative constants are foldable; a
    // negative size cannot be baked into the static shape.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands.  Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    // Walk the shape, consuming one dynamic-size operand per dynamic dim.
    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        // getZExtValue is safe here: the value was checked non-negative.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old memref.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
                                        dynamicSizes, alloc.getSymbolOperands(),
                                        alloc.getAlignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
    return success();
  }
};
284
285/// Fold alloc operations with no users or only store and dealloc uses.
286template <typename T>
287struct SimplifyDeadAlloc : public OpRewritePattern<T> {
288 using OpRewritePattern<T>::OpRewritePattern;
289
290 LogicalResult matchAndRewrite(T alloc,
291 PatternRewriter &rewriter) const override {
292 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
293 if (auto storeOp = dyn_cast<StoreOp>(op))
294 return storeOp.getValue() == alloc;
295 return !isa<DeallocOp>(op);
296 }))
297 return failure();
298
299 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
300 rewriter.eraseOp(user);
301
302 rewriter.eraseOp(alloc);
303 return success();
304 }
305};
306} // namespace
307
308void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
309 MLIRContext *context) {
310 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
311}
312
313void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
314 MLIRContext *context) {
315 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
316 context);
317}
318
319//===----------------------------------------------------------------------===//
320// ReallocOp
321//===----------------------------------------------------------------------===//
322
323LogicalResult ReallocOp::verify() {
324 auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
325 MemRefType resultType = getType();
326
327 // The source memref should have identity layout (or none).
328 if (!sourceType.getLayout().isIdentity())
329 return emitError("unsupported layout for source memref type ")
330 << sourceType;
331
332 // The result memref should have identity layout (or none).
333 if (!resultType.getLayout().isIdentity())
334 return emitError("unsupported layout for result memref type ")
335 << resultType;
336
337 // The source memref and the result memref should be in the same memory space.
338 if (sourceType.getMemorySpace() != resultType.getMemorySpace())
339 return emitError("different memory spaces specified for source memref "
340 "type ")
341 << sourceType << " and result memref type " << resultType;
342
343 // The source memref and the result memref should have the same element type.
344 if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
345 "result")))
346 return failure();
347
348 // Verify that we have the dynamic dimension operand when it is needed.
349 if (resultType.getNumDynamicDims() && !getDynamicResultSize())
350 return emitError("missing dimension operand for result type ")
351 << resultType;
352 if (!resultType.getNumDynamicDims() && getDynamicResultSize())
353 return emitError("unnecessary dimension operand for result type ")
354 << resultType;
355
356 return success();
357}
358
359void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
360 MLIRContext *context) {
361 results.add<SimplifyDeadAlloc<ReallocOp>>(context);
362}
363
364//===----------------------------------------------------------------------===//
365// AllocaScopeOp
366//===----------------------------------------------------------------------===//
367
368void AllocaScopeOp::print(OpAsmPrinter &p) {
369 bool printBlockTerminators = false;
370
371 p << ' ';
372 if (!getResults().empty()) {
373 p << " -> (" << getResultTypes() << ")";
374 printBlockTerminators = true;
375 }
376 p << ' ';
377 p.printRegion(getBodyRegion(),
378 /*printEntryBlockArgs=*/false,
379 /*printBlockTerminators=*/printBlockTerminators);
380 p.printOptionalAttrDict((*this)->getAttrs());
381}
382
383ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
384 // Create a region for the body.
385 result.regions.reserve(1);
386 Region *bodyRegion = result.addRegion();
387
388 // Parse optional results type list.
389 if (parser.parseOptionalArrowTypeList(result.types))
390 return failure();
391
392 // Parse the body region.
393 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
394 return failure();
395 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
396 result.location);
397
398 // Parse the optional attribute list.
399 if (parser.parseOptionalAttrDict(result.attributes))
400 return failure();
401
402 return success();
403}
404
void AllocaScopeOp::getSuccessorRegions(
// NOTE(review): the parameter-list line (declaring `point` and `regions`) is
// elided in this excerpt; the body is kept verbatim.
  if (!point.isParent()) {
    // Control leaving the body region returns to the op's parent.
    regions.push_back(RegionSuccessor::parent());
    return;
  }

  // From the op itself, control enters the body region.
  regions.push_back(RegionSuccessor(&getBodyRegion()));
}
414
415ValueRange AllocaScopeOp::getSuccessorInputs(RegionSuccessor successor) {
416 return successor.isParent() ? ValueRange(getResults()) : ValueRange();
417}
418
/// Given an operation, return whether this op is guaranteed to
/// allocate an AutomaticAllocationScopeResource
// NOTE(review): the declaration line (taking the operation) is elided in
// this excerpt; the body is kept verbatim.
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return false;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      // Only allocations of automatic (scope-bound) storage count.
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Given an operation, return whether this op itself could
/// allocate an AutomaticAllocationScopeResource. Note that
/// this will not check whether an operation contained within
/// the op can allocate.
// NOTE(review): the declaration line and the condition guarding the first
// early return are elided in this excerpt; the body is kept verbatim.
  // This op itself doesn't create a stack allocation,
  // the inner allocation should be handled separately.
    return false;
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  // Without a memory-effect interface, conservatively assume the op may
  // allocate.
  if (!interface)
    return true;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Return whether this op is the last non terminating op
/// in a region. That is to say, it is in a one-block region
/// and is only followed by a terminator. This prevents
/// extending the lifetime of allocations.
// NOTE(review): the declaration line and the final conjunct of the return
// expression are elided in this excerpt; the visible conjuncts are kept
// verbatim.
  return op->getBlock()->mightHaveTerminator() &&
         op->getNextNode() == op->getBlock()->getTerminator() &&
}
468
/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
/// or it contains no allocation.
struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {
    // Walk the body looking for anything that might perform an automatic
    // allocation; nested allocation scopes are skipped since they own their
    // allocations.
    bool hasPotentialAlloca =
        op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
          if (alloc == op)
            return WalkResult::advance();
          // NOTE(review): the condition guarding this interrupt is elided in
          // this excerpt.
            return WalkResult::interrupt();
          if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
            return WalkResult::skip();
          return WalkResult::advance();
        }).wasInterrupted();

    // If this contains no potential allocation, it is always legal to
    // inline. Otherwise, consider two conditions:
    if (hasPotentialAlloca) {
      // If the parent isn't an allocation scope, or we are not the last
      // non-terminator op in the parent, we will extend the lifetime.
      if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
        return failure();
      // NOTE(review): the condition guarding this second failure is elided
      // in this excerpt.
        return failure();
    }

    // Splice the single body block before the op, forwarding the
    // terminator's operands as the op's results.
    Block *block = &op.getRegion().front();
    Operation *terminator = block->getTerminator();
    ValueRange results = terminator->getOperands();
    rewriter.inlineBlockBefore(block, op);
    rewriter.replaceOp(op, results);
    rewriter.eraseOp(terminator);
    return success();
  }
};
507
/// Move allocations into an allocation scope, if it is legal to
/// move them (e.g. their operands are available at the location
/// the op would be moved to).
struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {

    // There must be some enclosing allocation scope to hoist into.
    if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    Operation *lastParentWithoutScope = op->getParentOp();

    if (!lastParentWithoutScope ||
        lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    // Only apply to if this is this last non-terminator
    // op in the block (lest lifetime be extended) of a one
    // block region
    if (!lastNonTerminatorInRegion(op) ||
        !lastNonTerminatorInRegion(lastParentWithoutScope))
      return failure();

    // Climb to the outermost ancestor that is not itself an allocation
    // scope, checking the "last non-terminator" condition at every level.
    // NOTE(review): the continuation of this while-condition is elided in
    // this excerpt.
    while (!lastParentWithoutScope->getParentOp()
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||
          !lastNonTerminatorInRegion(lastParentWithoutScope))
        return failure();
    }
    // NOTE(review): the continuation of this assert is elided in this
    // excerpt.
    assert(lastParentWithoutScope->getParentOp()

    // Find the unique region of the ancestor that contains `op`.
    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

    // Collect hoistable allocations.
    // NOTE(review): the declaration of `toHoist` and the condition guarding
    // the first skip inside the walk are elided in this excerpt.
    op->walk([&](Operation *alloc) {
        return WalkResult::skip();

      // If any operand is not defined before the location of
      // lastParentWithoutScope (i.e. where we would hoist to), skip.
      if (llvm::any_of(alloc->getOperands(), [&](Value v) {
            return containingRegion->isAncestor(v.getParentRegion());
          }))
        return WalkResult::skip();
      toHoist.push_back(alloc);
      return WalkResult::advance();
    });

    if (toHoist.empty())
      return failure();
    rewriter.setInsertionPoint(lastParentWithoutScope);
    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
    return success();
  }
};
578
579void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
580 MLIRContext *context) {
581 results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
582}
583
584//===----------------------------------------------------------------------===//
585// AssumeAlignmentOp
586//===----------------------------------------------------------------------===//
587
588LogicalResult AssumeAlignmentOp::verify() {
589 if (!llvm::isPowerOf2_32(getAlignment()))
590 return emitOpError("alignment must be power of 2");
591 return success();
592}
593
594void AssumeAlignmentOp::getAsmResultNames(
595 function_ref<void(Value, StringRef)> setNameFn) {
596 setNameFn(getResult(), "assume_align");
597}
598
599OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
600 auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
601 if (!source)
602 return {};
603 if (source.getAlignment() != getAlignment())
604 return {};
605 return getMemref();
606}
607
608FailureOr<std::optional<SmallVector<Value>>>
609AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
610 return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
611}
612
613FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
614 int resultIndex,
615 int dim) {
616 assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
617 return getMixedSize(builder, getLoc(), getMemref(), dim);
618}
619
620//===----------------------------------------------------------------------===//
621// DistinctObjectsOp
622//===----------------------------------------------------------------------===//
623
624LogicalResult DistinctObjectsOp::verify() {
625 if (getOperandTypes() != getResultTypes())
626 return emitOpError("operand types and result types must match");
627
628 if (getOperandTypes().empty())
629 return emitOpError("expected at least one operand");
630
631 return success();
632}
633
634LogicalResult DistinctObjectsOp::inferReturnTypes(
635 MLIRContext * /*context*/, std::optional<Location> /*location*/,
636 ValueRange operands, DictionaryAttr /*attributes*/,
637 PropertyRef /*properties*/, RegionRange /*regions*/,
638 SmallVectorImpl<Type> &inferredReturnTypes) {
639 llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
640 return success();
641}
642
643//===----------------------------------------------------------------------===//
644// CastOp
645//===----------------------------------------------------------------------===//
646
647void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
648 setNameFn(getResult(), "cast");
649}
650
651/// Determines whether MemRef_CastOp casts to a more dynamic version of the
652/// source memref. This is useful to fold a memref.cast into a consuming op
653/// and implement canonicalization patterns for ops in different dialects that
654/// may consume the results of memref.cast operations. Such foldable memref.cast
655/// operations are typically inserted as `view` and `subview` ops are
656/// canonicalized, to preserve the type compatibility of their uses.
657///
658/// Returns true when all conditions are met:
659/// 1. source and result are ranked memrefs with strided semantics and same
660/// element type and rank.
661/// 2. each of the source's size, offset or stride has more static information
662/// than the corresponding result's size, offset or stride.
663///
664/// Example 1:
665/// ```mlir
666/// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
667/// %2 = consumer %1 ... : memref<?x?xf32> ...
668/// ```
669///
670/// may fold into:
671///
672/// ```mlir
673/// %2 = consumer %0 ... : memref<8x16xf32> ...
674/// ```
675///
676/// Example 2:
677/// ```
678/// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
679/// to memref<?x?xf32>
680/// consumer %1 : memref<?x?xf32> ...
681/// ```
682///
683/// may fold into:
684///
685/// ```
686/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
687/// ```
688bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
689 MemRefType sourceType =
690 llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
691 MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
692
693 // Requires ranked MemRefType.
694 if (!sourceType || !resultType)
695 return false;
696
697 // Requires same elemental type.
698 if (sourceType.getElementType() != resultType.getElementType())
699 return false;
700
701 // Requires same rank.
702 if (sourceType.getRank() != resultType.getRank())
703 return false;
704
705 // Only fold casts between strided memref forms.
706 int64_t sourceOffset, resultOffset;
707 SmallVector<int64_t, 4> sourceStrides, resultStrides;
708 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
709 failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
710 return false;
711
712 // If cast is towards more static sizes along any dimension, don't fold.
713 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
714 auto ss = std::get<0>(it), st = std::get<1>(it);
715 if (ss != st)
716 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
717 return false;
718 }
719
720 // If cast is towards more static offset along any dimension, don't fold.
721 if (sourceOffset != resultOffset)
722 if (ShapedType::isDynamic(sourceOffset) &&
723 ShapedType::isStatic(resultOffset))
724 return false;
725
726 // If cast is towards more static strides along any dimension, don't fold.
727 for (auto it : llvm::zip(sourceStrides, resultStrides)) {
728 auto ss = std::get<0>(it), st = std::get<1>(it);
729 if (ss != st)
730 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
731 return false;
732 }
733
734 return true;
735}
736
737bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
738 if (inputs.size() != 1 || outputs.size() != 1)
739 return false;
740 Type a = inputs.front(), b = outputs.front();
741 auto aT = llvm::dyn_cast<MemRefType>(a);
742 auto bT = llvm::dyn_cast<MemRefType>(b);
743
744 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
745 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
746
747 if (aT && bT) {
748 if (aT.getElementType() != bT.getElementType())
749 return false;
750 if (aT.getLayout() != bT.getLayout()) {
751 int64_t aOffset, bOffset;
752 SmallVector<int64_t, 4> aStrides, bStrides;
753 if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
754 failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
755 aStrides.size() != bStrides.size())
756 return false;
757
758 // Strides along a dimension/offset are compatible if the value in the
759 // source memref is static and the value in the target memref is the
760 // same. They are also compatible if either one is dynamic (see
761 // description of MemRefCastOp for details).
762 // Note that for dimensions of size 1, the stride can differ.
763 auto checkCompatible = [](int64_t a, int64_t b) {
764 return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
765 };
766 if (!checkCompatible(aOffset, bOffset))
767 return false;
768 for (const auto &[index, aStride] : enumerate(aStrides)) {
769 if (aT.getDimSize(index) == 1 || bT.getDimSize(index) == 1)
770 continue;
771 if (!checkCompatible(aStride, bStrides[index]))
772 return false;
773 }
774 }
775 if (aT.getMemorySpace() != bT.getMemorySpace())
776 return false;
777
778 // They must have the same rank, and any specified dimensions must match.
779 if (aT.getRank() != bT.getRank())
780 return false;
781
782 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
783 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
784 if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
785 aDim != bDim)
786 return false;
787 }
788 return true;
789 } else {
790 if (!aT && !uaT)
791 return false;
792 if (!bT && !ubT)
793 return false;
794 // Unranked to unranked casting is unsupported
795 if (uaT && ubT)
796 return false;
797
798 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
799 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
800 if (aEltType != bEltType)
801 return false;
802
803 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
804 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
805 return aMemSpace == bMemSpace;
806 }
807
808 return false;
809}
810
811OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
812 return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
813}
814
815FailureOr<std::optional<SmallVector<Value>>>
816CastOp::bubbleDownCasts(OpBuilder &builder) {
817 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
818}
819
820//===----------------------------------------------------------------------===//
821// CopyOp
822//===----------------------------------------------------------------------===//
823
824namespace {
825
826/// Fold memref.copy(%x, %x).
827struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
828 using OpRewritePattern<CopyOp>::OpRewritePattern;
829
830 LogicalResult matchAndRewrite(CopyOp copyOp,
831 PatternRewriter &rewriter) const override {
832 if (copyOp.getSource() != copyOp.getTarget())
833 return failure();
834
835 rewriter.eraseOp(copyOp);
836 return success();
837 }
838};
839
840struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
841 using OpRewritePattern<CopyOp>::OpRewritePattern;
842
843 static bool isEmptyMemRef(BaseMemRefType type) {
844 return type.hasRank() && llvm::is_contained(type.getShape(), 0);
845 }
846
847 LogicalResult matchAndRewrite(CopyOp copyOp,
848 PatternRewriter &rewriter) const override {
849 if (isEmptyMemRef(copyOp.getSource().getType()) ||
850 isEmptyMemRef(copyOp.getTarget().getType())) {
851 rewriter.eraseOp(copyOp);
852 return success();
853 }
854
855 return failure();
856 }
857};
858} // namespace
859
860void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
861 MLIRContext *context) {
862 results.add<FoldEmptyCopy, FoldSelfCopy>(context);
863}
864
865/// If the source/target of a CopyOp is a CastOp that does not modify the shape
866/// and element type, the cast can be skipped. Such CastOps only cast the layout
867/// of the type.
868static LogicalResult foldCopyOfCast(CopyOp op) {
869 for (OpOperand &operand : op->getOpOperands()) {
870 auto castOp = operand.get().getDefiningOp<memref::CastOp>();
871 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
872 operand.set(castOp.getOperand());
873 return success();
874 }
875 }
876 return failure();
877}
878
879LogicalResult CopyOp::fold(FoldAdaptor adaptor,
880 SmallVectorImpl<OpFoldResult> &results) {
881
882 /// copy(memrefcast) -> copy
883 return foldCopyOfCast(*this);
884}
885
886//===----------------------------------------------------------------------===//
887// DeallocOp
888//===----------------------------------------------------------------------===//
889
890LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
891 SmallVectorImpl<OpFoldResult> &results) {
892 /// dealloc(memrefcast) -> dealloc
893 return foldMemRefCast(*this);
894}
895
896//===----------------------------------------------------------------------===//
897// DimOp
898//===----------------------------------------------------------------------===//
899
900void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
901 setNameFn(getResult(), "dim");
902}
903
904void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
905 int64_t index) {
906 auto loc = result.location;
907 Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
908 build(builder, result, source, indexValue);
909}
910
std::optional<int64_t> DimOp::getConstantIndex() {
// NOTE(review): the body line computing the constant value of the index
// operand is elided in this excerpt.
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
  // NOTE(review): the returned speculatability values after this and the
  // following two early-exit conditions are elided in this excerpt; the
  // visible conditions are kept verbatim.

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)

  // An index >= rank would make the dim op undefined.
  if (rankedSourceType.getRank() <= constantIndex)

}
929
/// Integer-range analysis hook: the result range of a dim op is derived from
/// the range of the index operand (argRanges[1]) via the shared
/// ShapedDimOpInterface inference helper.
void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}
935
936/// Return a map with key being elements in `vals` and data being number of
937/// occurences of it. Use std::map, since the `vals` here are strides and the
938/// dynamic stride value is the same as the tombstone value for
939/// `DenseMap<int64_t>`.
940static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
941 std::map<int64_t, unsigned> numOccurences;
942 for (auto val : vals)
943 numOccurences[val]++;
944 return numOccurences;
945}
946
947/// Returns the set of source dimensions that are dropped in a rank reduction.
948/// For each result dimension in order, matches the leftmost unmatched source
949/// dimension with the same size. Source dimensions not matched are dropped.
950///
951/// Example: memref<1x8x1x3> to memref<1x8x3>. Source sizes [1, 8, 1, 3], result
952/// [1, 8, 3]. Match result[0]=1 -> source dim 0, result[1]=8 -> source dim 1,
953/// result[2]=3 -> source dim 3. Source dim 2 is unmatched and dropped.
954static FailureOr<llvm::SmallBitVector>
956 MemRefType reducedType,
958 int64_t rankReduction = originalType.getRank() - reducedType.getRank();
959 if (rankReduction <= 0)
960 return llvm::SmallBitVector(originalType.getRank());
961
962 // Build source sizes from subview sizes (one per source dim).
963 SmallVector<int64_t> sourceSizes(originalType.getRank());
964 for (const auto &it : llvm::enumerate(sizes)) {
965 if (std::optional<int64_t> cst = getConstantIntValue(it.value()))
966 sourceSizes[it.index()] = *cst;
967 else
968 sourceSizes[it.index()] = ShapedType::kDynamic;
969 }
970
971 ArrayRef<int64_t> resultSizes = reducedType.getShape();
972 llvm::SmallBitVector usedSourceDims(originalType.getRank());
973 int64_t startJ = 0;
974 for (int64_t resultSize : resultSizes) {
975 bool matched = false;
976 for (int64_t j = startJ; j < originalType.getRank(); ++j) {
977 if (sourceSizes[j] == resultSize) {
978 usedSourceDims.set(j);
979 matched = true;
980 startJ = j + 1;
981 break;
982 }
983 }
984 if (!matched)
985 return failure();
986 }
987
988 llvm::SmallBitVector unusedDims(originalType.getRank());
989 for (int64_t i = 0; i < originalType.getRank(); ++i)
990 if (!usedSourceDims.test(i))
991 unusedDims.set(i);
992 return unusedDims;
993}
994
995/// Returns the set of source dimensions that are dropped in a rank reduction.
996/// A dimension is dropped if its stride is dropped; uses stride occurrence
997/// counting to disambiguate when multiple unit dims exist.
998///
999/// Example: memref<1x1x?xf32, strided<[?, 4, 1]>> to memref<1x4xf32,
1000/// strided<[4, 1]>>. Source strides [?, 4, 1], candidate [4, 1]. Dim 0 (stride
1001/// ?) can be dropped; dim 1 (stride 4) must be kept. Source dim 0 is dropped.
static FailureOr<llvm::SmallBitVector> computeMemRefRankReductionMaskByStrides(
    MemRefType originalType, MemRefType reducedType,
    ArrayRef<int64_t> originalStrides, ArrayRef<int64_t> candidateStrides,
    llvm::SmallBitVector unusedDims) {
  // Track the number of occurrences of the strides in the original type
  // and the candidate type. For each unused dim that stride should not be
  // present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to exactly figure out
  // which dim corresponds to which stride, we just need to verify that the
  // number of repetitions of a stride in the original + number of unused dims
  // with that stride == number of repetitions of a stride in the candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this is not dropped. Keep as is.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced rank
      // type that wasn't in the original one.
      return failure();
    }
  }
  // The surviving dropped dims must account exactly for the rank difference.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() !=
      originalType.getRank())
    return failure();
  return unusedDims;
}
1045
1046/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
1047/// to be a subset of `originalType` with some `1` entries erased, return the
1048/// set of indices that specifies which of the entries of `originalShape` are
1049/// dropped to obtain `reducedShape`.
1050/// This accounts for cases where there are multiple unit-dims, but only a
1051/// subset of those are dropped. For MemRefTypes these can be disambiguated
1052/// using the strides. If a dimension is dropped the stride must be dropped too.
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  // No rank reduction: nothing is dropped.
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  // Candidate dropped dims are exactly the static unit-size dims of the view.
  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the number
  // of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  // Ambiguous case: more unit dims than dropped dims. Use the strided layout
  // to decide which unit dims were actually dropped.
  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
      failed(
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
    return failure();

  // Try stride-based first when we have meaningful static stride info
  // (preserves static strides). Fall back to position-based otherwise.
  auto hasNonTrivialStaticStride = [](ArrayRef<int64_t> strides) {
    // The innermost stride 1 is trivial for row-major and does not help
    // disambiguate.
    if (strides.size() <= 1)
      return false;
    return llvm::any_of(strides.drop_back(),
                        [](int64_t s) { return !ShapedType::isDynamic(s); });
  };
  if (hasNonTrivialStaticStride(originalStrides) ||
      hasNonTrivialStaticStride(candidateStrides)) {
    FailureOr<llvm::SmallBitVector> strideBased =
        computeMemRefRankReductionMaskByStrides(originalType, reducedType,
                                                originalStrides,
                                                candidateStrides, unusedDims);
    if (succeeded(strideBased))
      return *strideBased;
  }
  return computeMemRefRankReductionMaskByPosition(originalType, reducedType,
                                                  sizes);
}
1101
1102llvm::SmallBitVector SubViewOp::getDroppedDims() {
1103 MemRefType sourceType = getSourceType();
1104 MemRefType resultType = getType();
1105 FailureOr<llvm::SmallBitVector> unusedDims =
1106 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1107 assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1108 return *unusedDims;
1109}
1110
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  // These ops carry one SSA size operand per dynamic dim, in dim order, so
  // the dynamic-dim index of `unsignedIndex` selects the right operand.
  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    // The result dim is dynamic (the static case was handled above). Dropped
    // dims always have static size 1, so dynamic source sizes are never
    // dropped and map in order to the dynamic result dims. Find the k-th
    // dynamic source size, where k is the dynamic dim index of the result dim.
    unsigned dynamicResultDimIdx = memrefType.getDynamicDimIndex(unsignedIndex);
    unsigned dynamicIdx = 0;
    for (OpFoldResult size : subview.getMixedSizes()) {
      if (llvm::isa<Attribute>(size))
        continue;
      if (dynamicIdx == dynamicResultDimIdx)
        return size;
      dynamicIdx++;
    }
    return {};
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
1175
namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand, i.e. dim(reshape(%mem, %shape), %i) -> load(%shape, %i).
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    // dim of a memref reshape can be folded if dim.getIndex() dominates the
    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
    // cheaply check that either of the following conditions hold:
    //      1. dim.getIndex() is defined in the same block as reshape but before
    //      reshape.
    //      2. dim.getIndex() is defined in a parent block of
    //      reshape.

    // Check condition 1
    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
        }
      } // else dim.getIndex is a block argument to reshape->getBlock and
        // dominates reshape
    }   // Check condition 2
    else if (dim->getBlock() != reshape->getBlock() &&
             !dim.getIndex().getParentRegion()->isProperAncestor(
                 reshape->getParentRegion())) {
      // If dim and reshape are in the same block but dim.getIndex() isn't, we
      // already know dim.getIndex() dominates reshape without calling
      // `isProperAncestor`
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");
    }

    // Place the load directly after the reshape to ensure that the shape memref
    // was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    // The shape memref's element type may differ from index; insert a cast to
    // keep the replacement type-correct.
    if (load.getType() != dim.getType())
      load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

} // namespace
1233
/// Register dim canonicalizations (currently only dim-of-reshape folding).
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}
1238
1239// ---------------------------------------------------------------------------
1240// DmaStartOp
1241// ---------------------------------------------------------------------------
1242
1243void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1244 Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1245 ValueRange destIndices, Value numElements,
1246 Value tagMemRef, ValueRange tagIndices, Value stride,
1247 Value elementsPerStride) {
1248 result.addOperands(srcMemRef);
1249 result.addOperands(srcIndices);
1250 result.addOperands(destMemRef);
1251 result.addOperands(destIndices);
1252 result.addOperands({numElements, tagMemRef});
1253 result.addOperands(tagIndices);
1254 if (stride)
1255 result.addOperands({stride, elementsPerStride});
1256}
1257
/// Print the custom assembly form; must mirror DmaStartOp::parse exactly:
///   %src[%i], %dst[%k], %num_elements, %tag[%idx] (, %stride, %elts_per)?
///     : src-type, dst-type, tag-type
void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}
1269
1270// Parse DmaStartOp.
1271// Ex:
1272// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1273// %tag[%index], %stride, %num_elt_per_stride :
1274// : memref<3076 x f32, 0>,
1275// memref<1024 x f32, 2>,
1276// memref<1 x i32>
1277//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  OpAsmParser::UnresolvedOperand tagMemrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size in KiB.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // Strides come as a pair (stride, elements per stride): zero or two only.
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  // Exactly three types are expected: src memref, dst memref, tag memref.
  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}
1337
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}
1403
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  // No results to fold; this only rewrites memref operands in place.
  return foldMemRefCast(*this);
}
1409
1410void DmaStartOp::setMemrefsAndIndices(RewriterBase &rewriter, Value newSrc,
1411 ValueRange newSrcIndices, Value newDst,
1412 ValueRange newDstIndices) {
1413 /// dma_start has special handling for variadic rank
1414 SmallVector<Value> newOperands;
1415 newOperands.push_back(newSrc);
1416 llvm::append_range(newOperands, newSrcIndices);
1417 newOperands.push_back(newDst);
1418 llvm::append_range(newOperands, newDstIndices);
1419 newOperands.push_back(getNumElements());
1420 newOperands.push_back(getTagMemRef());
1421 llvm::append_range(newOperands, getTagIndices());
1422 if (isStrided()) {
1423 newOperands.push_back(getStride());
1424 newOperands.push_back(getNumElementsPerStride());
1425 }
1426
1427 rewriter.modifyOpInPlace(*this, [&]() { (*this)->setOperands(newOperands); });
1428}
1429
1430// ---------------------------------------------------------------------------
1431// DmaWaitOp
1432// ---------------------------------------------------------------------------
1433
LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  // No results to fold; this only rewrites the tag memref operand in place.
  return foldMemRefCast(*this);
}
1439
1440LogicalResult DmaWaitOp::verify() {
1441 // Check that the number of tag indices matches the tagMemRef rank.
1442 unsigned numTagIndices = getTagIndices().size();
1443 unsigned tagMemRefRank = getTagMemRefRank();
1444 if (numTagIndices != tagMemRefRank)
1445 return emitOpError() << "expected tagIndices to have the same number of "
1446 "elements as the tagMemRef rank, expected "
1447 << tagMemRefRank << ", but got " << numTagIndices;
1448 return success();
1449}
1450
1451//===----------------------------------------------------------------------===//
1452// ExtractAlignedPointerAsIndexOp
1453//===----------------------------------------------------------------------===//
1454
/// Suggest "%intptr" as the pretty SSA name for the extracted pointer value.
void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}
1459
1460//===----------------------------------------------------------------------===//
1461// ExtractStridedMetadataOp
1462//===----------------------------------------------------------------------===//
1463
1464/// The number and type of the results are inferred from the
1465/// shape of the source.
1466LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1467 MLIRContext *context, std::optional<Location> location,
1468 ExtractStridedMetadataOp::Adaptor adaptor,
1469 SmallVectorImpl<Type> &inferredReturnTypes) {
1470 auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1471 if (!sourceType)
1472 return failure();
1473
1474 unsigned sourceRank = sourceType.getRank();
1475 IndexType indexType = IndexType::get(context);
1476 auto memrefType =
1477 MemRefType::get({}, sourceType.getElementType(),
1478 MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1479 // Base.
1480 inferredReturnTypes.push_back(memrefType);
1481 // Offset.
1482 inferredReturnTypes.push_back(indexType);
1483 // Sizes and strides.
1484 for (unsigned i = 0; i < sourceRank * 2; ++i)
1485 inferredReturnTypes.push_back(indexType);
1486 return success();
1487}
1488
void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  // For multi-result to work properly with pretty names and packed syntax `x:3`
  // we can only give a pretty name to the first value in the pack.
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}
1500
1501/// Helper function to perform the replacement of all constant uses of `values`
1502/// by a materialized constant extracted from `maybeConstants`.
1503/// `values` and `maybeConstants` are expected to have the same size.
1504template <typename Container>
1505static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1506 Container values,
1507 ArrayRef<OpFoldResult> maybeConstants) {
1508 assert(values.size() == maybeConstants.size() &&
1509 " expected values and maybeConstants of the same size");
1510 bool atLeastOneReplacement = false;
1511 for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1512 // Don't materialize a constant if there are no uses: this would indice
1513 // infinite loops in the driver.
1514 if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1515 continue;
1516 assert(isa<Attribute>(maybeConstant) &&
1517 "The constified value should be either unchanged (i.e., == result) "
1518 "or a constant");
1520 rewriter, loc,
1521 llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1522 for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1523 // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1524 // yet.
1525 op->replaceUsesOfWith(result, constantVal);
1526 atLeastOneReplacement = true;
1527 }
1528 }
1529 return atLeastOneReplacement;
1530}
1531
LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  // Replace uses of offset/sizes/strides results by materialized constants
  // whenever the source type makes them statically known.
  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
  if (auto prev = getSource().getDefiningOp<CastOp>())
    if (isa<MemRefType>(prev.getSource().getType())) {
      getSourceMutable().assign(prev.getSource());
      atLeastOneReplacement = true;
    }

  return success(atLeastOneReplacement);
}
1554
/// Return the sizes as OpFoldResults, replacing dynamic SSA values by the
/// corresponding static extents of the source type where known.
SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
  constifyIndexValues(values, getSource().getType().getShape());
  return values;
}
1560
/// Return the strides as OpFoldResults, replacing dynamic SSA values by the
/// static strides of the source's strided layout where known.
SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
  SmallVector<int64_t> staticValues;
  int64_t unused;
  // The source is verified to have a strided layout, so this cannot fail.
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  (void)status;
  assert(succeeded(status) && "could not get strides from type");
  constifyIndexValues(values, staticValues);
  return values;
}
1573
/// Return the offset as an OpFoldResult, replacing the dynamic SSA value by
/// the static offset of the source's strided layout where known.
OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
  // Wrap the single offset in vectors to reuse constifyIndexValues.
  SmallVector<OpFoldResult> values(1, offsetOfr);
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  // The source is verified to have a strided layout, so this cannot fail.
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  constifyIndexValues(values, staticValues);
  return values[0];
}
1587
1588//===----------------------------------------------------------------------===//
1589// GenericAtomicRMWOp
1590//===----------------------------------------------------------------------===//
1591
1592void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1593 Value memref, ValueRange ivs) {
1594 OpBuilder::InsertionGuard g(builder);
1595 result.addOperands(memref);
1596 result.addOperands(ivs);
1597
1598 if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1599 Type elementType = memrefType.getElementType();
1600 result.addTypes(elementType);
1601
1602 Region *bodyRegion = result.addRegion();
1603 builder.createBlock(bodyRegion);
1604 bodyRegion->addArgument(elementType, memref.getLoc());
1605 }
1606}
1607
LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  // Exactly one block argument: the current value of the memref element.
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  // The body executes atomically, so it may not touch memory itself; walk all
  // nested ops and reject any with memory effects.
  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (isMemoryEffectFree(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError(
                "body of 'memref.generic_atomic_rmw' should contain "
                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}
1628
1629ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1630 OperationState &result) {
1631 OpAsmParser::UnresolvedOperand memref;
1632 Type memrefType;
1633 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1634
1635 Type indexType = parser.getBuilder().getIndexType();
1636 if (parser.parseOperand(memref) ||
1638 parser.parseColonType(memrefType) ||
1639 parser.resolveOperand(memref, memrefType, result.operands) ||
1640 parser.resolveOperands(ivs, indexType, result.operands))
1641 return failure();
1642
1643 Region *body = result.addRegion();
1644 if (parser.parseRegion(*body, {}) ||
1645 parser.parseOptionalAttrDict(result.attributes))
1646 return failure();
1647 result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1648 return success();
1649}
1650
/// Print the custom assembly form; must mirror GenericAtomicRMWOp::parse.
void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}
1657
/// Memory-access interface hook: the accessed memref is the op's operand.
TypedValue<MemRefType> GenericAtomicRMWOp::getAccessedMemref() {
  return getMemref();
}
1661
/// Replace the accessed memref and its indices in place. Returns
/// std::nullopt to signal that no new result values were produced.
std::optional<SmallVector<Value>> GenericAtomicRMWOp::updateMemrefAndIndices(
    RewriterBase &rewriter, Value newMemref, ValueRange newIndices) {
  rewriter.modifyOpInPlace(*this, [&]() {
    getMemrefMutable().assign(newMemref);
    getIndicesMutable().assign(newIndices);
  });
  return std::nullopt;
}
1670
1671//===----------------------------------------------------------------------===//
1672// AtomicYieldOp
1673//===----------------------------------------------------------------------===//
1674
1675LogicalResult AtomicYieldOp::verify() {
1676 Type parentType = (*this)->getParentOp()->getResultTypes().front();
1677 Type resultType = getResult().getType();
1678 if (parentType != resultType)
1679 return emitOpError() << "types mismatch between yield op: " << resultType
1680 << " and its parent: " << parentType;
1681 return success();
1682}
1683
1684//===----------------------------------------------------------------------===//
1685// GlobalOp
1686//===----------------------------------------------------------------------===//
1687
1689 TypeAttr type,
1690 Attribute initialValue) {
1691 p << type;
1692 if (!op.isExternal()) {
1693 p << " = ";
1694 if (op.isUninitialized())
1695 p << "uninitialized";
1696 else
1697 p.printAttributeWithoutType(initialValue);
1698 }
1699}
1700
1701static ParseResult
1703 Attribute &initialValue) {
1704 Type type;
1705 if (parser.parseType(type))
1706 return failure();
1707
1708 auto memrefType = llvm::dyn_cast<MemRefType>(type);
1709 if (!memrefType || !memrefType.hasStaticShape())
1710 return parser.emitError(parser.getNameLoc())
1711 << "type should be static shaped memref, but got " << type;
1712 typeAttr = TypeAttr::get(type);
1713
1714 if (parser.parseOptionalEqual())
1715 return success();
1716
1717 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1718 initialValue = UnitAttr::get(parser.getContext());
1719 return success();
1720 }
1721
1722 Type tensorType = getTensorTypeFromMemRefType(memrefType);
1723 if (parser.parseAttribute(initialValue, tensorType))
1724 return failure();
1725 if (!llvm::isa<ElementsAttr>(initialValue))
1726 return parser.emitError(parser.getNameLoc())
1727 << "initial value should be a unit or elements attribute";
1728 return success();
1729}
1730
1731LogicalResult GlobalOp::verify() {
1732 auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1733 if (!memrefType || !memrefType.hasStaticShape())
1734 return emitOpError("type should be static shaped memref, but got ")
1735 << getType();
1736
1737 // Verify that the initial value, if present, is either a unit attribute or
1738 // an elements attribute.
1739 if (getInitialValue().has_value()) {
1740 Attribute initValue = getInitialValue().value();
1741 if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1742 return emitOpError("initial value should be a unit or elements "
1743 "attribute, but got ")
1744 << initValue;
1745
1746 // Check that the type of the initial value is compatible with the type of
1747 // the global variable.
1748 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1749 // Check the element types match.
1750 auto initElementType =
1751 cast<TensorType>(elementsAttr.getType()).getElementType();
1752 auto memrefElementType = memrefType.getElementType();
1753
1754 if (initElementType != memrefElementType)
1755 return emitOpError("initial value element expected to be of type ")
1756 << memrefElementType << ", but was of type " << initElementType;
1757
1758 // Check the shapes match, given that memref globals can only produce
1759 // statically shaped memrefs and elements literal type must have a static
1760 // shape we can assume both types are shaped.
1761 auto initShape = elementsAttr.getShapedType().getShape();
1762 auto memrefShape = memrefType.getShape();
1763 if (initShape != memrefShape)
1764 return emitOpError("initial value shape expected to be ")
1765 << memrefShape << " but was " << initShape;
1766 }
1767 }
1768
1769 // TODO: verify visibility for declarations.
1770 return success();
1771}
1772
1773ElementsAttr GlobalOp::getConstantInitValue() {
1774 auto initVal = getInitialValue();
1775 if (getConstant() && initVal.has_value())
1776 return llvm::cast<ElementsAttr>(initVal.value());
1777 return {};
1778}
1779
1780//===----------------------------------------------------------------------===//
1781// GetGlobalOp
1782//===----------------------------------------------------------------------===//
1783
1784LogicalResult
1785GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1786 // Verify that the result type is same as the type of the referenced
1787 // memref.global op.
1788 auto global =
1789 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1790 if (!global)
1791 return emitOpError("'")
1792 << getName() << "' does not reference a valid global memref";
1793
1794 Type resultType = getResult().getType();
1795 if (global.getType() != resultType)
1796 return emitOpError("result type ")
1797 << resultType << " does not match type " << global.getType()
1798 << " of the global memref @" << getName();
1799 return success();
1800}
1801
1802//===----------------------------------------------------------------------===//
1803// LoadOp
1804//===----------------------------------------------------------------------===//
1805
1806OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1807 /// load(memrefcast) -> load
1808 if (succeeded(foldMemRefCast(*this)))
1809 return getResult();
1810
1811 // Fold load from a global constant memref.
1812 auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
1813 if (!getGlobalOp)
1814 return {};
1815
1816 // Get to the memref.global defining the symbol.
1818 getGlobalOp, getGlobalOp.getNameAttr());
1819 if (!global)
1820 return {};
1821 // If it's a splat constant, we can fold irrespective of indices.
1822 auto splatAttr =
1823 dyn_cast_or_null<SplatElementsAttr>(global.getConstantInitValue());
1824 if (!splatAttr)
1825 return {};
1826
1827 return splatAttr.getSplatValue<Attribute>();
1828}
1829
1830TypedValue<MemRefType> LoadOp::getAccessedMemref() { return getMemref(); }
1831
1832std::optional<SmallVector<Value>>
1833LoadOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
1834 ValueRange newIndices) {
1835 rewriter.modifyOpInPlace(*this, [&]() {
1836 getMemrefMutable().assign(newMemref);
1837 getIndicesMutable().assign(newIndices);
1838 });
1839 return std::nullopt;
1840}
1841
1842FailureOr<std::optional<SmallVector<Value>>>
1843LoadOp::bubbleDownCasts(OpBuilder &builder) {
1845 getResult());
1846}
1847
1848//===----------------------------------------------------------------------===//
1849// MemorySpaceCastOp
1850//===----------------------------------------------------------------------===//
1851
1852void MemorySpaceCastOp::getAsmResultNames(
1853 function_ref<void(Value, StringRef)> setNameFn) {
1854 setNameFn(getResult(), "memspacecast");
1855}
1856
1857bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1858 if (inputs.size() != 1 || outputs.size() != 1)
1859 return false;
1860 Type a = inputs.front(), b = outputs.front();
1861 auto aT = llvm::dyn_cast<MemRefType>(a);
1862 auto bT = llvm::dyn_cast<MemRefType>(b);
1863
1864 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1865 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1866
1867 if (aT && bT) {
1868 if (aT.getElementType() != bT.getElementType())
1869 return false;
1870 if (aT.getLayout() != bT.getLayout())
1871 return false;
1872 if (aT.getShape() != bT.getShape())
1873 return false;
1874 return true;
1875 }
1876 if (uaT && ubT) {
1877 return uaT.getElementType() == ubT.getElementType();
1878 }
1879 return false;
1880}
1881
1882OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1883 // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1884 // t2)
1885 if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1886 getSourceMutable().assign(parentCast.getSource());
1887 return getResult();
1888 }
1889 return Value{};
1890}
1891
1892TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
1893 return getSource();
1894}
1895
1896TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
1897 return getDest();
1898}
1899
1900bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1901 PtrLikeTypeInterface src) {
1902 return isa<BaseMemRefType>(tgt) &&
1903 tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
1904}
1905
1906MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
1907 OpBuilder &b, PtrLikeTypeInterface tgt,
1909 assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
1910 return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
1911}
1912
1913/// The only cast we recognize as promotable is to the generic space.
1914bool MemorySpaceCastOp::isSourcePromotable() {
1915 return getDest().getType().getMemorySpace() == nullptr;
1916}
1917
1918//===----------------------------------------------------------------------===//
1919// PrefetchOp
1920//===----------------------------------------------------------------------===//
1921
1922void PrefetchOp::print(OpAsmPrinter &p) {
1923 p << " " << getMemref() << '[';
1925 p << ']' << ", " << (getIsWrite() ? "write" : "read");
1926 p << ", locality<" << getLocalityHint();
1927 p << ">, " << (getIsDataCache() ? "data" : "instr");
1929 (*this)->getAttrs(),
1930 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1931 p << " : " << getMemRefType();
1932}
1933
1934ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1935 OpAsmParser::UnresolvedOperand memrefInfo;
1936 SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1937 IntegerAttr localityHint;
1938 MemRefType type;
1939 StringRef readOrWrite, cacheType;
1940
1941 auto indexTy = parser.getBuilder().getIndexType();
1942 auto i32Type = parser.getBuilder().getIntegerType(32);
1943 if (parser.parseOperand(memrefInfo) ||
1945 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1946 parser.parseComma() || parser.parseKeyword("locality") ||
1947 parser.parseLess() ||
1948 parser.parseAttribute(localityHint, i32Type, "localityHint",
1949 result.attributes) ||
1950 parser.parseGreater() || parser.parseComma() ||
1951 parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1952 parser.resolveOperand(memrefInfo, type, result.operands) ||
1953 parser.resolveOperands(indexInfo, indexTy, result.operands))
1954 return failure();
1955
1956 if (readOrWrite != "read" && readOrWrite != "write")
1957 return parser.emitError(parser.getNameLoc(),
1958 "rw specifier has to be 'read' or 'write'");
1959 result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1960 parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1961
1962 if (cacheType != "data" && cacheType != "instr")
1963 return parser.emitError(parser.getNameLoc(),
1964 "cache type has to be 'data' or 'instr'");
1965
1966 result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1967 parser.getBuilder().getBoolAttr(cacheType == "data"));
1968
1969 return success();
1970}
1971
1972LogicalResult PrefetchOp::verify() {
1973 if (getNumOperands() != 1 + getMemRefType().getRank())
1974 return emitOpError("too few indices");
1975
1976 return success();
1977}
1978
LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  // Rewrites the memref operand in place; `results` stays empty because
  // prefetch produces no values.
  return foldMemRefCast(*this);
}
1984
1985TypedValue<MemRefType> PrefetchOp::getAccessedMemref() { return getMemref(); }
1986
1987std::optional<SmallVector<Value>>
1988PrefetchOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
1989 ValueRange newIndices) {
1990 rewriter.modifyOpInPlace(*this, [&]() {
1991 getMemrefMutable().assign(newMemref);
1992 getIndicesMutable().assign(newIndices);
1993 });
1994 return std::nullopt;
1995}
1996
1997//===----------------------------------------------------------------------===//
1998// RankOp
1999//===----------------------------------------------------------------------===//
2000
2001OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
2002 // Constant fold rank when the rank of the operand is known.
2003 auto type = getOperand().getType();
2004 auto shapedType = llvm::dyn_cast<ShapedType>(type);
2005 if (shapedType && shapedType.hasRank())
2006 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
2007 return IntegerAttr();
2008}
2009
2010//===----------------------------------------------------------------------===//
2011// ReinterpretCastOp
2012//===----------------------------------------------------------------------===//
2013
2014void ReinterpretCastOp::getAsmResultNames(
2015 function_ref<void(Value, StringRef)> setNameFn) {
2016 setNameFn(getResult(), "reinterpret_cast");
2017}
2018
2019/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
2020/// `staticSizes` and `staticStrides` are automatically filled with
2021/// source-memref-rank sentinel values that encode dynamic entries.
2022void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
2023 MemRefType resultType, Value source,
2024 OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
2025 ArrayRef<OpFoldResult> strides,
2026 ArrayRef<NamedAttribute> attrs) {
2027 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2028 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2029 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
2030 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2031 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2032 result.addAttributes(attrs);
2033 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2034 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2035 b.getDenseI64ArrayAttr(staticSizes),
2036 b.getDenseI64ArrayAttr(staticStrides));
2037}
2038
2039void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
2040 Value source, OpFoldResult offset,
2041 ArrayRef<OpFoldResult> sizes,
2042 ArrayRef<OpFoldResult> strides,
2043 ArrayRef<NamedAttribute> attrs) {
2044 auto sourceType = cast<BaseMemRefType>(source.getType());
2045 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2046 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2047 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
2048 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2049 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2050 auto stridedLayout = StridedLayoutAttr::get(
2051 b.getContext(), staticOffsets.front(), staticStrides);
2052 auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
2053 stridedLayout, sourceType.getMemorySpace());
2054 build(b, result, resultType, source, offset, sizes, strides, attrs);
2055}
2056
2057void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
2058 MemRefType resultType, Value source,
2059 int64_t offset, ArrayRef<int64_t> sizes,
2060 ArrayRef<int64_t> strides,
2061 ArrayRef<NamedAttribute> attrs) {
2062 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
2063 sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
2064 SmallVector<OpFoldResult> strideValues =
2065 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
2066 return b.getI64IntegerAttr(v);
2067 });
2068 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
2069 strideValues, attrs);
2070}
2071
2072void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
2073 MemRefType resultType, Value source, Value offset,
2074 ValueRange sizes, ValueRange strides,
2075 ArrayRef<NamedAttribute> attrs) {
2076 SmallVector<OpFoldResult> sizeValues =
2077 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2078 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2079 strides, [](Value v) -> OpFoldResult { return v; });
2080 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
2081}
2082
2083// TODO: ponder whether we want to allow missing trailing sizes/strides that are
2084// completed automatically, like we have for subview and extract_slice.
LogicalResult ReinterpretCastOp::verify() {
  // The source and result memrefs should be in the same memory space.
  auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
  auto resultType = llvm::cast<MemRefType>(getType());
  if (srcType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
  if (failed(verifyElementTypesMatch(*this, srcType, resultType, "source",
                                     "result")))
    return failure();

  // Match sizes in result memref type and in static_sizes attribute.
  // A dynamic dim in the result type is compatible with any attribute entry;
  // only a static-vs-static mismatch is an error.
  for (auto [idx, resultSize, expectedSize] :
       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
    if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
      return emitError("expected result type with size = ")
             << (ShapedType::isDynamic(expectedSize)
                     ? std::string("dynamic")
                     : std::to_string(expectedSize))
             << " instead of " << resultSize << " in dim = " << idx;
  }

  // Match offset and strides in static_offset and static_strides attributes. If
  // result memref type has no affine map specified, this will assume an
  // identity layout.
  int64_t resultOffset;
  SmallVector<int64_t, 4> resultStrides;
  if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return emitError("expected result type to have strided layout but found ")
           << resultType;

  // Match offset in result memref type and in static_offsets attribute.
  int64_t expectedOffset = getStaticOffsets().front();
  if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
    return emitError("expected result type with offset = ")
           << (ShapedType::isDynamic(expectedOffset)
                   ? std::string("dynamic")
                   : std::to_string(expectedOffset))
           << " instead of " << resultOffset;

  // Match strides in result memref type and in static_strides attribute.
  for (auto [idx, resultStride, expectedStride] :
       llvm::enumerate(resultStrides, getStaticStrides())) {
    if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
      return emitError("expected result type with stride = ")
             << (ShapedType::isDynamic(expectedStride)
                     ? std::string("dynamic")
                     : std::to_string(expectedStride))
             << " instead of " << resultStride << " in dim = " << idx;
  }

  return success();
}
2138
2139OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
2140 Value src = getSource();
2141 auto getPrevSrc = [&]() -> Value {
2142 // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
2143 if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
2144 return prev.getSource();
2145
2146 // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
2147 if (auto prev = src.getDefiningOp<CastOp>())
2148 return prev.getSource();
2149
2150 // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
2151 // are 0.
2152 if (auto prev = src.getDefiningOp<SubViewOp>())
2153 if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
2154 return prev.getSource();
2155
2156 return nullptr;
2157 };
2158
2159 if (auto prevSrc = getPrevSrc()) {
2160 getSourceMutable().assign(prevSrc);
2161 return getResult();
2162 }
2163
2164 // reinterpret_cast(x) w/o offset/shape/stride changes -> x
2165 if (ShapedType::isStaticShape(getType().getShape()) &&
2166 src.getType() == getType() && getStaticOffsets().front() == 0) {
2167 return src;
2168 }
2169
2170 return nullptr;
2171}
2172
2173SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
2174 SmallVector<OpFoldResult> values = getMixedSizes();
2176 return values;
2177}
2178
2179SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
2180 SmallVector<OpFoldResult> values = getMixedStrides();
2181 SmallVector<int64_t> staticValues;
2182 int64_t unused;
2183 LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
2184 (void)status;
2185 assert(succeeded(status) && "could not get strides from type");
2186 constifyIndexValues(values, staticValues);
2187 return values;
2188}
2189
2190OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
2191 SmallVector<OpFoldResult> values = getMixedOffsets();
2192 assert(values.size() == 1 &&
2193 "reinterpret_cast must have one and only one offset");
2194 SmallVector<int64_t> staticValues, unused;
2195 int64_t offset;
2196 LogicalResult status = getType().getStridesAndOffset(unused, offset);
2197 (void)status;
2198 assert(succeeded(status) && "could not get offset from type");
2199 staticValues.push_back(offset);
2200 constifyIndexValues(values, staticValues);
2201 return values[0];
2202}
2203
2204namespace {
2205/// Replace the sequence:
2206/// ```
2207/// base, offset, sizes, strides = extract_strided_metadata src
2208/// dst = reinterpret_cast base to offset, sizes, strides
2209/// ```
2210/// With
2211///
2212/// ```
2213/// dst = memref.cast src
2214/// ```
2215///
2216/// Note: The cast operation is only inserted when the type of dst and src
2217/// are not the same. E.g., when going from <4xf32> to <?xf32>.
2218///
2219/// This pattern also matches when the offset, sizes, and strides don't come
2220/// directly from the `extract_strided_metadata`'s results but it can be
2221/// statically proven that they would hold the same values.
2222///
2223/// For instance, the following sequence would be replaced:
2224/// ```
2225/// base, offset, sizes, strides =
2226/// extract_strided_metadata memref : memref<3x4xty>
2227/// dst = reinterpret_cast base to 0, [3, 4], strides
2228/// ```
2229/// Because we know (thanks to the type of the input memref) that variable
2230/// `offset` and `sizes` will respectively hold 0 and [3, 4].
2231///
2232/// Similarly, the following sequence would be replaced:
2233/// ```
2234/// c0 = arith.constant 0
2235/// c4 = arith.constant 4
2236/// base, offset, sizes, strides =
2237/// extract_strided_metadata memref : memref<3x4xty>
2238/// dst = reinterpret_cast base to c0, [3, c4], strides
2239/// ```
2240/// Because we know that `offset`and `c0` will hold 0
2241/// and `c4` will hold 4.
2242///
2243/// If the pattern above does not match, the input of the
2244/// extract_strided_metadata is always folded into the input of the
2245/// reinterpret_cast operator. This allows for dead code elimination to get rid
2246/// of the extract_strided_metadata in some cases.
struct ReinterpretCastOpExtractStridedMetadataFolder
    : public OpRewritePattern<ReinterpretCastOp> {
public:
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)
      return failure();

    // Check if the reinterpret cast reconstructs a memref with the exact same
    // properties as the extract strided metadata.
    auto isReinterpretCastNoop = [&]() -> bool {
      // First, check that the strides are the same.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
                       op.getConstifiedMixedStrides()))
        return false;

      // Second, check the sizes.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                       op.getConstifiedMixedSizes()))
        return false;

      // Finally, check the offset.
      assert(op.getMixedOffsets().size() == 1 &&
             "reinterpret_cast with more than one offset should have been "
             "rejected by the verifier");
      return extractStridedMetadata.getConstifiedMixedOffset() ==
             op.getConstifiedMixedOffset();
    };

    if (!isReinterpretCastNoop()) {
      // If the extract_strided_metadata / reinterpret_cast pair can't be
      // completely folded, then we could fold the input of the
      // extract_strided_metadata into the input of the reinterpret_cast.
      // For some cases (e.g., static dimensions) the
      // extract_strided_metadata is eliminated by dead code elimination.
      //
      // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
      //
      // We can always fold the input of an extract_strided_metadata operator
      // to the input of a reinterpret_cast operator, because they point to
      // the same memory. Note that the reinterpret_cast does not use the
      // layout of its input memref, only its base memory pointer which is
      // the same as the base pointer returned by the extract_strided_metadata
      // operator and the base pointer of the extract_strided_metadata memref
      // input.
      rewriter.modifyOpInPlace(op, [&]() {
        op.getSourceMutable().assign(extractStridedMetadata.getSource());
      });
      return success();
    }

    // At this point, we know that the back and forth between extract strided
    // metadata and reinterpret cast is a noop. However, the final type of the
    // reinterpret cast may not be exactly the same as the original memref.
    // E.g., it could be changing a dimension from static to dynamic. Check that
    // here and add a cast if necessary.
    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
    else
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
                                          extractStridedMetadata.getSource());

    return success();
  }
};
2317
/// Constify the mixed offset/sizes/strides of a reinterpret_cast: replace
/// dynamic entries whose values are statically implied by the result type
/// with constants, and insert a memref.cast back to the original result type.
/// Bails out when nothing new becomes static, or when a constified value is
/// negative (which would produce an invalid memref type).
struct ReinterpretCastOpConstantFolder
    : public OpRewritePattern<ReinterpretCastOp> {
public:
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    // Number of entries that are already static before constification.
    unsigned srcStaticCount = llvm::count_if(
        llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
                                   op.getMixedStrides()),
        [](OpFoldResult ofr) { return isa<Attribute>(ofr); });

    SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
    SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
    SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();

    // If the offset is a negative constant, we can't fold it because the
    // resulting memref type would be invalid. In that case, we keep the
    // original offset.
    if (auto cst = getConstantIntValue(offsets[0]))
      if (*cst < 0)
        offsets[0] = op.getMixedOffsets()[0];

    // If the size is a negative constant, we can't fold it because the
    // resulting memref type would be invalid. In that case, we keep the
    // original size.
    for (auto it : llvm::zip(op.getMixedSizes(), sizes)) {
      auto &srcSizeOfr = std::get<0>(it);
      auto &sizeOfr = std::get<1>(it);
      if (auto cst = getConstantIntValue(sizeOfr))
        if (*cst < 0)
          sizeOfr = srcSizeOfr;
    }

    // TODO: Using counting comparison instead of direct comparison because
    // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
    // IntegerAttrs, while constifyIndexValues (and therefore
    // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
    if (srcStaticCount ==
        llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
                       [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
      return failure();

    auto newReinterpretCast = ReinterpretCastOp::create(
        rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);

    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
    return success();
  }
};
2368} // namespace
2369
void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  // Fold extract_strided_metadata/reinterpret_cast round-trips and constify
  // statically-known offsets/sizes/strides.
  results.add<ReinterpretCastOpExtractStridedMetadataFolder,
              ReinterpretCastOpConstantFolder>(context);
}
2375
FailureOr<std::optional<SmallVector<Value>>>
ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
  // Pass-through op: bubbling a memory-space cast below this op only requires
  // rewriting the source operand.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
2380
2381//===----------------------------------------------------------------------===//
2382// Reassociative reshape ops
2383//===----------------------------------------------------------------------===//
2384
2385void CollapseShapeOp::getAsmResultNames(
2386 function_ref<void(Value, StringRef)> setNameFn) {
2387 setNameFn(getResult(), "collapse_shape");
2388}
2389
2390void ExpandShapeOp::getAsmResultNames(
2391 function_ref<void(Value, StringRef)> setNameFn) {
2392 setNameFn(getResult(), "expand_shape");
2393}
2394
2395LogicalResult ExpandShapeOp::reifyResultShapes(
2396 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2397 reifiedResultShapes = {
2398 getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2399 return success();
2400}
2401
2402/// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
2403/// result and operand. Layout maps are verified separately.
2404///
2405/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2406/// allowed in a reassocation group.
2407static LogicalResult
2409 ArrayRef<int64_t> expandedShape,
2410 ArrayRef<ReassociationIndices> reassociation,
2411 bool allowMultipleDynamicDimsPerGroup) {
2412 // There must be one reassociation group per collapsed dimension.
2413 if (collapsedShape.size() != reassociation.size())
2414 return op->emitOpError("invalid number of reassociation groups: found ")
2415 << reassociation.size() << ", expected " << collapsedShape.size();
2416
2417 // The next expected expanded dimension index (while iterating over
2418 // reassociation indices).
2419 int64_t nextDim = 0;
2420 for (const auto &it : llvm::enumerate(reassociation)) {
2421 ReassociationIndices group = it.value();
2422 int64_t collapsedDim = it.index();
2423
2424 bool foundDynamic = false;
2425 for (int64_t expandedDim : group) {
2426 if (expandedDim != nextDim++)
2427 return op->emitOpError("reassociation indices must be contiguous");
2428
2429 if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2430 return op->emitOpError("reassociation index ")
2431 << expandedDim << " is out of bounds";
2432
2433 // Check if there are multiple dynamic dims in a reassociation group.
2434 if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2435 if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2436 return op->emitOpError(
2437 "at most one dimension in a reassociation group may be dynamic");
2438 foundDynamic = true;
2439 }
2440 }
2441
2442 // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2443 if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2444 return op->emitOpError("collapsed dim (")
2445 << collapsedDim
2446 << ") must be dynamic if and only if reassociation group is "
2447 "dynamic";
2448
2449 // If all dims in the reassociation group are static, the size of the
2450 // collapsed dim can be verified.
2451 if (!foundDynamic) {
2452 int64_t groupSize = 1;
2453 for (int64_t expandedDim : group)
2454 groupSize *= expandedShape[expandedDim];
2455 if (groupSize != collapsedShape[collapsedDim])
2456 return op->emitOpError("collapsed dim size (")
2457 << collapsedShape[collapsedDim]
2458 << ") must equal reassociation group size (" << groupSize << ")";
2459 }
2460 }
2461
2462 if (collapsedShape.empty()) {
2463 // Rank 0: All expanded dimensions must be 1.
2464 for (int64_t d : expandedShape)
2465 if (d != 1)
2466 return op->emitOpError(
2467 "rank 0 memrefs can only be extended/collapsed with/from ones");
2468 } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2469 // Rank >= 1: Number of dimensions among all reassociation groups must match
2470 // the result memref rank.
2471 return op->emitOpError("expanded rank (")
2472 << expandedShape.size()
2473 << ") inconsistent with number of reassociation indices (" << nextDim
2474 << ")";
2475 }
2476
2477 return success();
2478}
2479
2480SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2481 return getSymbolLessAffineMaps(getReassociationExprs());
2482}
2483
2484SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2486 getReassociationIndices());
2487}
2488
2489SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2490 return getSymbolLessAffineMaps(getReassociationExprs());
2491}
2492
2493SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2495 getReassociationIndices());
2496}
2497
2498/// Compute the layout map after expanding a given source MemRef type with the
2499/// specified reassociation indices.
2500static FailureOr<StridedLayoutAttr>
2501computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2502 ArrayRef<ReassociationIndices> reassociation) {
2503 int64_t srcOffset;
2504 SmallVector<int64_t> srcStrides;
2505 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2506 return failure();
2507 assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2508
2509 // 1-1 mapping between srcStrides and reassociation packs.
2510 // Each srcStride starts with the given value and gets expanded according to
2511 // the proper entries in resultShape.
2512 // Example:
2513 // srcStrides = [10000, 1 , 100 ],
2514 // reassociations = [ [0], [1], [2, 3, 4]],
2515 // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2516 // -> For the purpose of stride calculation, the useful sizes are:
2517 // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2518 // resultStrides = [10000, 1, 600, 200, 100]
2519 // Note that a stride does not get expanded along the first entry of each
2520 // shape pack.
2521 SmallVector<int64_t> reverseResultStrides;
2522 reverseResultStrides.reserve(resultShape.size());
2523 unsigned shapeIndex = resultShape.size() - 1;
2524 for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2525 ReassociationIndices reassoc = std::get<0>(it);
2526 int64_t currentStrideToExpand = std::get<1>(it);
2527 for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2528 reverseResultStrides.push_back(currentStrideToExpand);
2529 currentStrideToExpand =
2530 (SaturatedInteger::wrap(currentStrideToExpand) *
2531 SaturatedInteger::wrap(resultShape[shapeIndex--]))
2532 .asInteger();
2533 }
2534 }
2535 auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2536 resultStrides.resize(resultShape.size(), 1);
2537 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2538}
2539
2540FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2541 MemRefType srcType, ArrayRef<int64_t> resultShape,
2542 ArrayRef<ReassociationIndices> reassociation) {
2543 if (srcType.getLayout().isIdentity()) {
2544 // If the source is contiguous (i.e., no layout map specified), so is the
2545 // result.
2546 MemRefLayoutAttrInterface layout;
2547 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2548 srcType.getMemorySpace());
2549 }
2550
2551 // Source may not be contiguous. Compute the layout map.
2552 FailureOr<StridedLayoutAttr> computedLayout =
2553 computeExpandedLayoutMap(srcType, resultShape, reassociation);
2554 if (failed(computedLayout))
2555 return failure();
2556 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2557 srcType.getMemorySpace());
2558}
2559
2560FailureOr<SmallVector<OpFoldResult>>
2561ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2562 MemRefType expandedType,
2563 ArrayRef<ReassociationIndices> reassociation,
2564 ArrayRef<OpFoldResult> inputShape) {
2565 std::optional<SmallVector<OpFoldResult>> outputShape =
2566 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2567 inputShape);
2568 if (!outputShape)
2569 return failure();
2570 return *outputShape;
2571}
2572
2573void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2574 Type resultType, Value src,
2575 ArrayRef<ReassociationIndices> reassociation,
2576 ArrayRef<OpFoldResult> outputShape) {
2577 auto [staticOutputShape, dynamicOutputShape] =
2578 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2579 build(builder, result, llvm::cast<MemRefType>(resultType), src,
2580 getReassociationIndicesAttribute(builder, reassociation),
2581 dynamicOutputShape, staticOutputShape);
2582}
2583
2584void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2585 Type resultType, Value src,
2586 ArrayRef<ReassociationIndices> reassociation) {
2587 SmallVector<OpFoldResult> inputShape =
2588 getMixedSizes(builder, result.location, src);
2589 MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2590 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2591 builder, result.location, memrefResultTy, reassociation, inputShape);
2592 // Failure of this assertion usually indicates presence of multiple
2593 // dynamic dimensions in the same reassociation group.
2594 assert(succeeded(outputShape) && "unable to infer output shape");
2595 build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2596}
2597
2598void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2599 ArrayRef<int64_t> resultShape, Value src,
2600 ArrayRef<ReassociationIndices> reassociation) {
2601 // Only ranked memref source values are supported.
2602 auto srcType = llvm::cast<MemRefType>(src.getType());
2603 FailureOr<MemRefType> resultType =
2604 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2605 // Failure of this assertion usually indicates a problem with the source
2606 // type, e.g., could not get strides/offset.
2607 assert(succeeded(resultType) && "could not compute layout");
2608 build(builder, result, *resultType, src, reassociation);
2609}
2610
2611void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2612 ArrayRef<int64_t> resultShape, Value src,
2613 ArrayRef<ReassociationIndices> reassociation,
2614 ArrayRef<OpFoldResult> outputShape) {
2615 // Only ranked memref source values are supported.
2616 auto srcType = llvm::cast<MemRefType>(src.getType());
2617 FailureOr<MemRefType> resultType =
2618 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2619 // Failure of this assertion usually indicates a problem with the source
2620 // type, e.g., could not get strides/offset.
2621 assert(succeeded(resultType) && "could not compute layout");
2622 build(builder, result, *resultType, src, reassociation, outputShape);
2623}
2624
/// Verify an expand_shape: rank must not decrease, the reassociation must be
/// consistent with both shapes, the result type must match the computed
/// expanded type, and the output_shape operands/attribute must agree with the
/// result type.
LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  // An expansion may not lose rank.
  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";
  }

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                  resultType.getShape(),
                                  getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute expected result type (including layout map).
  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  // Check actual result type.
  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  // static_output_shape must provide one entry per result dimension.
  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  // Each kDynamic entry in static_output_shape needs a matching SSA value in
  // output_shape.
  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  // Verify that the number of dynamic dims in output_shape matches the number
  // of dynamic dims in the result type.
  if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
                                         getOutputShape())))
    return failure();

  // Verify if provided output shapes are in agreement with output type.
  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
  ArrayRef<int64_t> resShape = getResult().getType().getShape();
  for (auto [pos, shape] : llvm::enumerate(resShape)) {
    if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
      return emitOpError("invalid output shape provided at pos ") << pos;
    }
  }

  return success();
}
2686
/// Canonicalization pattern that folds a producing memref.cast into the
/// memref.expand_shape consuming it, recomputing the expanded result type
/// from the cast's source when needed.
struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
public:
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp op,
                                PatternRewriter &rewriter) const override {
    // Only fire when the source is produced by a foldable memref.cast.
    auto cast = op.getSrc().getDefiningOp<CastOp>();
    if (!cast)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(cast))
      return failure();

    SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
    SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
    SmallVector<int64_t> newOutputShapeSizes;

    // Convert output shape dims from dynamic to static where possible.
    for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
      std::optional<int64_t> sizeOpt = getConstantIntValue(dimSize);
      if (!sizeOpt.has_value()) {
        newOutputShapeSizes.push_back(ShapedType::kDynamic);
        continue;
      }

      newOutputShapeSizes.push_back(sizeOpt.value());
      newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
    }

    // Refuse the fold when it would flip a source dim between dynamic and
    // static relative to its reassociation group's output dims.
    Value castSource = cast.getSource();
    auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
    SmallVector<ReassociationIndices> reassociationIndices =
        op.getReassociationIndices();
    for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
      auto newOutputShapeSizesSlice =
          ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
      bool newOutputDynamic =
          llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
      if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
        return rewriter.notifyMatchFailure(
            op, "folding cast will result in changing dynamicity in "
                "reassociation group");
    }

    FailureOr<MemRefType> newResultTypeOrFailure =
        ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
                                           reassociationIndices);

    if (failed(newResultTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "could not compute new expanded type after folding cast");

    // Same result type: just rewire the source in place. Otherwise build a
    // new expand_shape and cast back to the original result type.
    if (*newResultTypeOrFailure == op.getResultType()) {
      rewriter.modifyOpInPlace(
          op, [&]() { op.getSrcMutable().assign(castSource); });
    } else {
      Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
                                          *newResultTypeOrFailure, castSource,
                                          reassociationIndices, newOutputShape);
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
    }
    return success();
  }
};
2751
2752void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2753 MLIRContext *context) {
2754 results.add<
2755 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2756 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2757 ExpandShapeOpMemRefCastFolder>(context);
2758}
2759
/// Delegate cast bubbling to the shared pass-through implementation, which
/// operates on the `src` operand of this op.
FailureOr<std::optional<SmallVector<Value>>>
ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
}
2764
/// Compute the layout map after collapsing a given source MemRef type with the
/// specified reassociation indices.
///
/// Note: All collapsed dims in a reassociation group must be contiguous. It is
/// not possible to check this by inspecting a MemRefType in the general case.
/// If non-contiguity cannot be checked statically, the collapse is assumed to
/// be valid (and thus accepted by this function) unless `strict = true`.
static FailureOr<StridedLayoutAttr>
computeCollapsedLayoutMap(MemRefType srcType,
                          ArrayRef<ReassociationIndices> reassociation,
                          bool strict = false) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  auto srcShape = srcType.getShape();
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  // The result stride of a reassociation group is the stride of the last entry
  // of the reassociation. (TODO: Should be the minimum stride in the
  // reassociation because strides are not necessarily sorted. E.g., when using
  // memref.transpose.) Dimensions of size 1 should be skipped, because their
  // strides are meaningless and could have any arbitrary value.
  SmallVector<int64_t> resultStrides;
  resultStrides.reserve(reassociation.size());
  for (const ReassociationIndices &reassoc : reassociation) {
    ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
    // Skip trailing static unit dims; their strides carry no information.
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
    } else {
      // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
      // the corresponding stride may have to be skipped. (See above comment.)
      // Therefore, the result stride cannot be statically determined and must
      // be dynamic.
      resultStrides.push_back(ShapedType::kDynamic);
    }
  }

  // Validate that each reassociation group is contiguous.
  unsigned resultStrideIndex = resultStrides.size() - 1;
  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      // Accumulate the stride expected for the next-outer dimension.
      stride = stride * SaturatedInteger::wrap(srcShape[idx]);

      // Dimensions of size 1 should be skipped, because their strides are
      // meaningless and could have any arbitrary value.
      if (srcShape[idx - 1] == 1)
        continue;

      // Both source and result stride must have the same static value. In that
      // case, we can be sure, that the dimensions are collapsible (because they
      // are contiguous).
      // If `strict = false` (default during op verification), we accept cases
      // where one or both strides are dynamic. This is best effort: We reject
      // ops where obviously non-contiguous dims are collapsed, but accept ops
      // where we cannot be sure statically. Such ops may fail at runtime. See
      // the op documentation for details.
      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
      if (strict && (stride.saturated || srcStride.saturated))
        return failure();

      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
        return failure();
    }
  }
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}
2835
2836bool CollapseShapeOp::isGuaranteedCollapsible(
2837 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2838 // MemRefs with identity layout are always collapsible.
2839 if (srcType.getLayout().isIdentity())
2840 return true;
2841
2842 return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2843 /*strict=*/true));
2844}
2845
2846MemRefType CollapseShapeOp::computeCollapsedType(
2847 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2848 SmallVector<int64_t> resultShape;
2849 resultShape.reserve(reassociation.size());
2850 for (const ReassociationIndices &group : reassociation) {
2851 auto groupSize = SaturatedInteger::wrap(1);
2852 for (int64_t srcDim : group)
2853 groupSize =
2854 groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2855 resultShape.push_back(groupSize.asInteger());
2856 }
2857
2858 if (srcType.getLayout().isIdentity()) {
2859 // If the source is contiguous (i.e., no layout map specified), so is the
2860 // result.
2861 MemRefLayoutAttrInterface layout;
2862 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2863 srcType.getMemorySpace());
2864 }
2865
2866 // Source may not be fully contiguous. Compute the layout map.
2867 // Note: Dimensions that are collapsed into a single dim are assumed to be
2868 // contiguous.
2869 FailureOr<StridedLayoutAttr> computedLayout =
2870 computeCollapsedLayoutMap(srcType, reassociation);
2871 assert(succeeded(computedLayout) &&
2872 "invalid source layout map or collapsing non-contiguous dims");
2873 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2874 srcType.getMemorySpace());
2875}
2876
2877void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2878 ArrayRef<ReassociationIndices> reassociation,
2879 ArrayRef<NamedAttribute> attrs) {
2880 auto srcType = llvm::cast<MemRefType>(src.getType());
2881 MemRefType resultType =
2882 CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2884 getReassociationIndicesAttribute(b, reassociation));
2885 build(b, result, resultType, src, attrs);
2886}
2887
/// Verify a collapse_shape: rank must not increase, the reassociation must be
/// consistent with both shapes, and the result type must match the collapsed
/// type computed from the source.
LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  // A collapse may not gain rank.
  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";
  }

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                  srcType.getShape(), getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute expected result type (including layout map).
  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());
  } else {
    // Source may not be fully contiguous. Compute the layout map.
    // Note: Dimensions that are collapsed into a single dim are assumed to be
    // contiguous.
    FailureOr<StridedLayoutAttr> computedLayout =
        computeCollapsedLayoutMap(srcType, getReassociationIndices());
    if (failed(computedLayout))
      return emitOpError(
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpace());
  }

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;

  return success();
}
2935
2937 : public OpRewritePattern<CollapseShapeOp> {
2938public:
2939 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2940
2941 LogicalResult matchAndRewrite(CollapseShapeOp op,
2942 PatternRewriter &rewriter) const override {
2943 auto cast = op.getOperand().getDefiningOp<CastOp>();
2944 if (!cast)
2945 return failure();
2946
2947 if (!CastOp::canFoldIntoConsumerOp(cast))
2948 return failure();
2949
2950 Type newResultType = CollapseShapeOp::computeCollapsedType(
2951 llvm::cast<MemRefType>(cast.getOperand().getType()),
2952 op.getReassociationIndices());
2953
2954 if (newResultType == op.getResultType()) {
2955 rewriter.modifyOpInPlace(
2956 op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2957 } else {
2958 Value newOp =
2959 CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
2960 op.getReassociationIndices());
2961 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2962 }
2963 return success();
2964 }
2965};
2966
2967void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2968 MLIRContext *context) {
2969 results.add<
2970 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2971 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2972 memref::DimOp, MemRefType>,
2973 CollapseShapeOpMemRefCastFolder>(context);
2974}
2975
2976OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2978 adaptor.getOperands());
2979}
2980
2981OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2983 adaptor.getOperands());
2984}
2985
/// Delegate cast bubbling to the shared pass-through implementation, which
/// operates on the `src` operand of this op.
FailureOr<std::optional<SmallVector<Value>>>
CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
}
2990
2991//===----------------------------------------------------------------------===//
2992// ReshapeOp
2993//===----------------------------------------------------------------------===//
2994
/// Suggest "reshape" as the printed SSA name for this op's result.
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}
2999
/// Verify a memref.reshape: element types must match, any memref operand and
/// result must have identity layout, and a statically-ranked result requires
/// a static shape operand of matching length.
LogicalResult ReshapeOp::verify() {
  Type operandType = getSource().getType();
  Type resultType = getResult().getType();

  // Source and result must agree on the element type.
  Type operandElementType =
      llvm::cast<ShapedType>(operandType).getElementType();
  Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
  if (operandElementType != resultElementType)
    return emitOpError("element types of source and destination memref "
                       "types should be the same");

  // A ranked memref source must be contiguous (identity layout).
  if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
    if (!operandMemRefType.getLayout().isIdentity())
      return emitOpError("source memref type should have identity affine map");

  // Length of the 1-D shape operand (kDynamic if its size is dynamic).
  int64_t shapeSize =
      llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
  auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
  if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamic)
      return emitOpError("cannot use shape operand with dynamic length to "
                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
      return emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
}
3030
/// Delegate cast bubbling to the shared pass-through implementation, which
/// operates on the `source` operand of this op.
FailureOr<std::optional<SmallVector<Value>>>
ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
3035
3036//===----------------------------------------------------------------------===//
3037// StoreOp
3038//===----------------------------------------------------------------------===//
3039
/// Fold a memref.cast feeding this store into the store itself; the stored
/// value operand is excluded from the fold.
LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this, getValueToStore());
}
3045
/// Return the memref this store writes to (memory-access interface hook).
TypedValue<MemRefType> StoreOp::getAccessedMemref() { return getMemref(); }
3047
/// Replace this store's memref and index operands in place.
/// Returns std::nullopt because the op is updated, not replaced by new
/// values.
std::optional<SmallVector<Value>>
StoreOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
                                ValueRange newIndices) {
  rewriter.modifyOpInPlace(*this, [&]() {
    getMemrefMutable().assign(newMemref);
    getIndicesMutable().assign(newIndices);
  });
  return std::nullopt;
}
3057
3058FailureOr<std::optional<SmallVector<Value>>>
3059StoreOp::bubbleDownCasts(OpBuilder &builder) {
3061 ValueRange());
3062}
3063
3064//===----------------------------------------------------------------------===//
3065// SubViewOp
3066//===----------------------------------------------------------------------===//
3067
/// Suggest "subview" as the printed SSA name for this op's result.
void SubViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "subview");
}
3072
3073/// A subview result type can be fully inferred from the source type and the
3074/// static representation of offsets, sizes and strides. Special sentinels
3075/// encode the dynamic case.
3076MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
3077 ArrayRef<int64_t> staticOffsets,
3078 ArrayRef<int64_t> staticSizes,
3079 ArrayRef<int64_t> staticStrides) {
3080 unsigned rank = sourceMemRefType.getRank();
3081 (void)rank;
3082 assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
3083 assert(staticSizes.size() == rank && "staticSizes length mismatch");
3084 assert(staticStrides.size() == rank && "staticStrides length mismatch");
3085
3086 // Extract source offset and strides.
3087 auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
3088
3089 // Compute target offset whose value is:
3090 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
3091 int64_t targetOffset = sourceOffset;
3092 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
3093 auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
3094 targetOffset = (SaturatedInteger::wrap(targetOffset) +
3095 SaturatedInteger::wrap(staticOffset) *
3096 SaturatedInteger::wrap(sourceStride))
3097 .asInteger();
3098 }
3099
3100 // Compute target stride whose value is:
3101 // `sourceStrides_i * staticStrides_i`.
3102 SmallVector<int64_t, 4> targetStrides;
3103 targetStrides.reserve(staticOffsets.size());
3104 for (auto it : llvm::zip(sourceStrides, staticStrides)) {
3105 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
3106 targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
3107 SaturatedInteger::wrap(staticStride))
3108 .asInteger());
3109 }
3110
3111 // The type is now known.
3112 return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
3113 StridedLayoutAttr::get(sourceMemRefType.getContext(),
3114 targetOffset, targetStrides),
3115 sourceMemRefType.getMemorySpace());
3116}
3117
3118MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
3119 ArrayRef<OpFoldResult> offsets,
3120 ArrayRef<OpFoldResult> sizes,
3121 ArrayRef<OpFoldResult> strides) {
3122 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3123 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3124 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3125 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3126 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3127 if (!hasValidSizesOffsets(staticOffsets))
3128 return {};
3129 if (!hasValidSizesOffsets(staticSizes))
3130 return {};
3131 if (!hasValidStrides(staticStrides))
3132 return {};
3133 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
3134 staticSizes, staticStrides);
3135}
3136
3137MemRefType SubViewOp::inferRankReducedResultType(
3138 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
3139 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3140 ArrayRef<int64_t> strides) {
3141 MemRefType inferredType =
3142 inferResultType(sourceRankedTensorType, offsets, sizes, strides);
3143 assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
3144 "expected ");
3145 if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
3146 return inferredType;
3147
3148 // Compute which dimensions are dropped.
3149 std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
3150 computeRankReductionMask(inferredType.getShape(), resultShape);
3151 assert(dimsToProject.has_value() && "invalid rank reduction");
3152
3153 // Compute the layout and result type.
3154 auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
3155 SmallVector<int64_t> rankReducedStrides;
3156 rankReducedStrides.reserve(resultShape.size());
3157 for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
3158 if (!dimsToProject->contains(idx))
3159 rankReducedStrides.push_back(value);
3160 }
3161 return MemRefType::get(resultShape, inferredType.getElementType(),
3162 StridedLayoutAttr::get(inferredLayout.getContext(),
3163 inferredLayout.getOffset(),
3164 rankReducedStrides),
3165 inferredType.getMemorySpace());
3166}
3167
/// Mixed-operand overload: decompose the offsets/sizes/strides and defer to
/// the static-entry overload.
MemRefType SubViewOp::inferRankReducedResultType(
    ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
    ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
    ArrayRef<OpFoldResult> strides) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  return SubViewOp::inferRankReducedResultType(
      resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
      staticStrides);
}
3181
// Build a SubViewOp with mixed static and dynamic entries and custom result
// type. If the type passed is nullptr, it is inferred.
void SubViewOp::build(OpBuilder &b, OperationState &result,
                      MemRefType resultType, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  // Split the mixed entries into static values and dynamic SSA values.
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
  // Structuring implementation this way avoids duplication between builders.
  if (!resultType) {
    resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
                                            staticSizes, staticStrides);
  }
  result.addAttributes(attrs);
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
}
3207
// Build a SubViewOp with mixed static and dynamic entries and inferred result
// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  // A null result type requests inference in the primary mixed builder.
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}
3217
3218// Build a SubViewOp with static entries and inferred result type.
3219void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3220 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3221 ArrayRef<int64_t> strides,
3222 ArrayRef<NamedAttribute> attrs) {
3223 SmallVector<OpFoldResult> offsetValues =
3224 llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
3225 return b.getI64IntegerAttr(v);
3226 });
3227 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
3228 sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
3229 SmallVector<OpFoldResult> strideValues =
3230 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
3231 return b.getI64IntegerAttr(v);
3232 });
3233 build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
3234}
3235
3236// Build a SubViewOp with dynamic entries and custom result type. If the
3237// type passed is nullptr, it is inferred.
3238void SubViewOp::build(OpBuilder &b, OperationState &result,
3239 MemRefType resultType, Value source,
3240 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3241 ArrayRef<int64_t> strides,
3242 ArrayRef<NamedAttribute> attrs) {
3243 SmallVector<OpFoldResult> offsetValues =
3244 llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
3245 return b.getI64IntegerAttr(v);
3246 });
3247 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
3248 sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
3249 SmallVector<OpFoldResult> strideValues =
3250 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
3251 return b.getI64IntegerAttr(v);
3252 });
3253 build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
3254 attrs);
3255}
3256
3257// Build a SubViewOp with dynamic entries and custom result type. If the type
3258// passed is nullptr, it is inferred.
3259void SubViewOp::build(OpBuilder &b, OperationState &result,
3260 MemRefType resultType, Value source, ValueRange offsets,
3261 ValueRange sizes, ValueRange strides,
3262 ArrayRef<NamedAttribute> attrs) {
3263 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3264 offsets, [](Value v) -> OpFoldResult { return v; });
3265 SmallVector<OpFoldResult> sizeValues =
3266 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
3267 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3268 strides, [](Value v) -> OpFoldResult { return v; });
3269 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
3270}
3271
// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  // Passing a null MemRefType asks the delegated overload to infer the
  // result type from `source` and the offsets/sizes/strides.
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}
3278
/// For ViewLikeOpInterface: a subview aliases (is a view into) its `source`
/// operand.
Value SubViewOp::getViewSource() { return getSource(); }
3281
3282/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
3283/// static value).
3284static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
3285 int64_t t1Offset, t2Offset;
3286 SmallVector<int64_t> t1Strides, t2Strides;
3287 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3288 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3289 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
3290}
3291
3292/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
3293/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
3294/// marked as dropped in `droppedDims`.
3295static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
3296 const llvm::SmallBitVector &droppedDims) {
3297 assert(size_t(t1.getRank()) == droppedDims.size() &&
3298 "incorrect number of bits");
3299 assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
3300 "incorrect number of dropped dims");
3301 int64_t t1Offset, t2Offset;
3302 SmallVector<int64_t> t1Strides, t2Strides;
3303 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3304 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3305 if (failed(res1) || failed(res2))
3306 return false;
3307 for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
3308 if (droppedDims[i])
3309 continue;
3310 if (t1Strides[i] != t2Strides[j])
3311 return false;
3312 ++j;
3313 }
3314 return true;
3315}
3316
    SubViewOp op, Type expectedType) {
  // Emit the diagnostic that corresponds to a failed slice verification
  // result against `expectedType`.
  // NOTE(review): the function header (likely `static LogicalResult
  // produceSubViewErrorMsg(SliceVerificationResult result,`) and the
  // `case SliceVerificationResult::...:` labels of the switch below appear
  // to have been lost in extraction — restore from upstream before building.
  auto memrefType = llvm::cast<ShapedType>(expectedType);
  switch (result) {
  // Verification succeeded: nothing to report.
    return success();
  // Result rank exceeds the source rank.
    return op->emitError("expected result rank to be smaller or equal to ")
           << "the source rank, but got " << op.getType();
  // Result sizes do not match the inferred (possibly rank-reduced) type.
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result sizes), but got "
           << op.getType();
  // Element type mismatch.
    return op->emitError("expected result element type to be ")
           << memrefType.getElementType() << ", but got " << op.getType();
  // Memory space mismatch.
    return op->emitError(
               "expected result and source memory spaces to match, but got ")
           << op.getType();
  // Layout mismatch.
    return op->emitError("expected result type to be ")
           << expectedType
           << " or a rank-reduced version. (mismatch of result layout), but "
              "got "
           << op.getType();
  }
  llvm_unreachable("unexpected subview verification result");
}
3347
/// Verifier for SubViewOp.
///
/// Checks, in order: matching memory spaces, a strided base layout, shape/
/// element-type compatibility (allowing rank reductions), offset and stride
/// compatibility with the inferred type, and in-bounds offsets/sizes/strides.
///
/// NOTE(review): several `return produceSubViewErrorMsg(
/// SliceVerificationResult::...,` call openers in this function appear to
/// have been lost in extraction (orphaned `*this, expectedType);` lines
/// below) — restore from upstream before building.
LogicalResult SubViewOp::verify() {
  MemRefType baseType = getSourceType();
  MemRefType subViewType = getType();
  ArrayRef<int64_t> staticOffsets = getStaticOffsets();
  ArrayRef<int64_t> staticSizes = getStaticSizes();
  ArrayRef<int64_t> staticStrides = getStaticStrides();

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != subViewType.getMemorySpace())
    return emitError("different memory spaces specified for base memref "
                     "type ")
           << baseType << " and subview memref type " << subViewType;

  // Verify that the base memref type has a strided layout map.
  if (!baseType.isStrided())
    return emitError("base type ") << baseType << " is not strided";

  // Compute the expected result type, assuming that there are no rank
  // reductions.
  MemRefType expectedType = SubViewOp::inferResultType(
      baseType, staticOffsets, staticSizes, staticStrides);

  // Verify all properties of a shaped type: rank, element type and dimension
  // sizes. This takes into account potential rank reductions.
  auto shapedTypeVerification = isRankReducedType(
      /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
  if (shapedTypeVerification != SliceVerificationResult::Success)
    return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);

  // Make sure that the memory space did not change.
  if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
    *this, expectedType);

  // Verify the offset of the layout map.
  if (!haveCompatibleOffsets(expectedType, subViewType))
    *this, expectedType);

  // The only thing that's left to verify now are the strides. First, compute
  // the unused dimensions due to rank reductions. We have to look at sizes and
  // strides to decide which dimensions were dropped. This function also
  // partially verifies strides in case of rank reductions.
  auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
                                                   getMixedSizes());
  if (failed(unusedDims))
    *this, expectedType);

  // Strides must match.
  if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
    *this, expectedType);

  // Verify that offsets, sizes, strides do not run out-of-bounds with respect
  // to the base memref.
  SliceBoundsVerificationResult boundsResult =
      verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
                          staticStrides, /*generateErrorMessage=*/true);
  if (!boundsResult.isValid)
    return getOperation()->emitError(boundsResult.errorMessage);

  return success();
}
3413
  // Print a Range as `range <offset>:<size>:<stride>`.
  // NOTE(review): the enclosing function signature (a raw_ostream `operator<<`
  // for Range) appears to have been lost in extraction; verify upstream.
  return os << "range " << range.offset << ":" << range.size << ":"
            << range.stride;
}
3418
/// Return the list of Range (i.e. offset, size, stride). Each Range
/// entry contains either the dynamic value or a ConstantIndexOp constructed
/// with `b` at location `loc`.
SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
                                              OpBuilder &b, Location loc) {
  std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
  assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
  assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
  // NOTE(review): the declaration of `res` (likely
  // `SmallVector<Range, 8> res;`) appears to have been lost in extraction —
  // restore from upstream before building.
  unsigned rank = ranks[0];
  res.reserve(rank);
  for (unsigned idx = 0; idx < rank; ++idx) {
    // For each dimension use the dynamic operand when present; otherwise
    // materialize the static value as an arith constant index op.
    Value offset =
        op.isDynamicOffset(idx)
            ? op.getDynamicOffset(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
    Value size =
        op.isDynamicSize(idx)
            ? op.getDynamicSize(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
    Value stride =
        op.isDynamicStride(idx)
            ? op.getDynamicStride(idx)
            : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
    res.emplace_back(Range{offset, size, stride});
  }
  return res;
}
3447
/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
/// to deduce the result type for the given `sourceType`. Additionally, reduce
/// the rank of the inferred result type if `currentResultType` is lower rank
/// than `currentSourceType`. Use this signature if `sourceType` is updated
/// together with the result type. In this case, it is important to compute
/// the dropped dimensions using `currentSourceType` whose strides align with
/// `currentResultType`.
///
/// Returns null when no rank-reduction mask between the current source and
/// result types can be computed.
// NOTE(review): the signature opener (likely `static MemRefType
// getCanonicalSubViewResultType(`) appears to have been lost in extraction —
// restore from upstream before building.
    MemRefType currentResultType, MemRefType currentSourceType,
    MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
    ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
  MemRefType nonRankReducedType = SubViewOp::inferResultType(
      sourceType, mixedOffsets, mixedSizes, mixedStrides);
  FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
      currentSourceType, currentResultType, mixedSizes);
  if (failed(unusedDims))
    return nullptr;

  // Drop the unused dimensions from the inferred (non-rank-reduced) type's
  // shape and strides to build the rank-reduced result type.
  auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
  SmallVector<int64_t> shape, strides;
  unsigned numDimsAfterReduction =
      nonRankReducedType.getRank() - unusedDims->count();
  shape.reserve(numDimsAfterReduction);
  strides.reserve(numDimsAfterReduction);
  for (const auto &[idx, size, stride] :
       llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
                 nonRankReducedType.getShape(), layout.getStrides())) {
    if (unusedDims->test(idx))
      continue;
    shape.push_back(size);
    strides.push_back(stride);
  }

  return MemRefType::get(shape, nonRankReducedType.getElementType(),
                         StridedLayoutAttr::get(sourceType.getContext(),
                                                layout.getOffset(), strides),
                         nonRankReducedType.getMemorySpace());
}
3486
// NOTE(review): the signature opener (likely `Value
// mlir::memref::createCanonicalRankReducingSubViewOp(`) and the declaration
// of `sizes` appear to have been lost in extraction — restore from upstream.
// Creates a subview of `memref` with zero offsets, unit strides, and the
// given rank-reduced `targetShape`.
    OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
  auto memrefType = llvm::cast<MemRefType>(memref.getType());
  unsigned rank = memrefType.getRank();
  SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
  SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
  MemRefType targetType = SubViewOp::inferRankReducedResultType(
      targetShape, memrefType, offsets, sizes, strides);
  return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
                                           sizes, strides);
}
3499
3500FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3501 Value value,
3502 ArrayRef<int64_t> desiredShape) {
3503 auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3504 assert(sourceMemrefType && "not a ranked memref type");
3505 auto sourceShape = sourceMemrefType.getShape();
3506 if (sourceShape.equals(desiredShape))
3507 return value;
3508 auto maybeRankReductionMask =
3509 mlir::computeRankReductionMask(sourceShape, desiredShape);
3510 if (!maybeRankReductionMask)
3511 return failure();
3512 return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3513}
3514
3515/// Helper method to check if a `subview` operation is trivially a no-op. This
3516/// is the case if the all offsets are zero, all strides are 1, and the source
3517/// shape is same as the size of the subview. In such cases, the subview can
3518/// be folded into its source.
3519static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3520 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3521 return false;
3522
3523 auto mixedOffsets = subViewOp.getMixedOffsets();
3524 auto mixedSizes = subViewOp.getMixedSizes();
3525 auto mixedStrides = subViewOp.getMixedStrides();
3526
3527 // Check offsets are zero.
3528 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3529 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3530 return !intValue || intValue.value() != 0;
3531 }))
3532 return false;
3533
3534 // Check strides are one.
3535 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3536 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3537 return !intValue || intValue.value() != 1;
3538 }))
3539 return false;
3540
3541 // Check all size values are static and matches the (static) source shape.
3542 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3543 for (const auto &size : llvm::enumerate(mixedSizes)) {
3544 std::optional<int64_t> intValue = getConstantIntValue(size.value());
3545 if (!intValue || *intValue != sourceShape[size.index()])
3546 return false;
3547 }
3548 // All conditions met. The `SubViewOp` is foldable as a no-op.
3549 return true;
3550}
3551
3552namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
/// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
/// ```
/// is rewritten into:
/// ```
/// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
/// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
/// memref<3x4xf32, strided<[?, 1], offset: ?>>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick
    // in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the MemRefCastOp source operand type to infer the result type and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    // Rebuild the subview on the cast's source, re-using the original
    // dynamic operands and static attribute arrays unchanged.
    Value newSubView = SubViewOp::create(
        rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    // Cast back to the original result type so existing users stay valid.
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};
3611
3612/// Canonicalize subview ops that are no-ops. When the source shape is not
3613/// same as a result shape due to use of `affine_map`.
3614class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3615public:
3616 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3617
3618 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3619 PatternRewriter &rewriter) const override {
3620 if (!isTrivialSubViewOp(subViewOp))
3621 return failure();
3622 if (subViewOp.getSourceType() == subViewOp.getType()) {
3623 rewriter.replaceOp(subViewOp, subViewOp.getSource());
3624 return success();
3625 }
3626 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3627 subViewOp.getSource());
3628 return success();
3629 }
3630};
3631} // namespace
3632
/// Return the canonical type of the result of a subview.
// NOTE(review): the enclosing `struct SubViewReturnTypeCanonicalizer {`
// opener appears to have been lost in extraction — restore from upstream.
  MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
                        ArrayRef<OpFoldResult> mixedSizes,
                        ArrayRef<OpFoldResult> mixedStrides) {
    // Infer a memref type without taking into account any rank reductions.
    MemRefType resTy = SubViewOp::inferResultType(
        op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
    if (!resTy)
      return {};
    MemRefType nonReducedType = resTy;

    // Directly return the non-rank reduced type if there are no dropped dims.
    llvm::SmallBitVector droppedDims = op.getDroppedDims();
    if (droppedDims.none())
      return nonReducedType;

    // Take the strides and offset from the non-rank reduced type.
    auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();

    // Drop dims from shape and strides.
    SmallVector<int64_t> targetShape;
    SmallVector<int64_t> targetStrides;
    for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
      if (droppedDims.test(i))
        continue;
      targetStrides.push_back(nonReducedStrides[i]);
      targetShape.push_back(nonReducedType.getDimSize(i));
    }

    // Rebuild the rank-reduced type with an explicit strided layout.
    return MemRefType::get(targetShape, nonReducedType.getElementType(),
                           StridedLayoutAttr::get(nonReducedType.getContext(),
                                                  offset, targetStrides),
                           nonReducedType.getMemorySpace());
  }
};
3669
/// A canonicalizer wrapper to replace SubViewOps.
// NOTE(review): the enclosing `struct SubViewCanonicalizer {` opener appears
// to have been lost in extraction — restore from upstream.
  void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
    // Cast the canonicalized subview back to the original result type so
    // users of `op` remain type-correct.
    rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
  }
};
3676
void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Registered patterns: fold constant offset/size/stride operands into the
  // static attribute arrays, push producer memref.casts past the subview,
  // and erase subviews that are trivially no-ops.
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}
3684
3685OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3686 MemRefType sourceMemrefType = getSource().getType();
3687 MemRefType resultMemrefType = getResult().getType();
3688 auto resultLayout =
3689 dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3690
3691 if (resultMemrefType == sourceMemrefType &&
3692 resultMemrefType.hasStaticShape() &&
3693 (!resultLayout || resultLayout.hasStaticLayout())) {
3694 return getViewSource();
3695 }
3696
3697 // Fold subview(subview(x)), where both subviews have the same size and the
3698 // second subview's offsets are all zero. (I.e., the second subview is a
3699 // no-op.)
3700 if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3701 auto srcSizes = srcSubview.getMixedSizes();
3702 auto sizes = getMixedSizes();
3703 auto offsets = getMixedOffsets();
3704 bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3705 auto strides = getMixedStrides();
3706 bool allStridesOne = llvm::all_of(strides, isOneInteger);
3707 bool allSizesSame = llvm::equal(sizes, srcSizes);
3708 if (allOffsetsZero && allStridesOne && allSizesSame &&
3709 resultMemrefType == sourceMemrefType)
3710 return getViewSource();
3711 }
3712
3713 return {};
3714}
3715
FailureOr<std::optional<SmallVector<Value>>>
SubViewOp::bubbleDownCasts(OpBuilder &builder) {
  // Delegate to the shared pass-through implementation on the source operand.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
3720
/// Propagate integer-range metadata (offset/sizes/strides) from the source
/// memref through this subview. Bails out (sets nothing) whenever any
/// required operand range or the source metadata is not yet initialized.
void SubViewOp::inferStridedMetadataRanges(
    ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
    SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
  auto isUninitialized =
      +[](IntegerValueRange range) { return range.isUninitialized(); };

  // Bail early if any of the operands metadata is not ready:
  SmallVector<IntegerValueRange> offsetOperands =
      getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
  if (llvm::any_of(offsetOperands, isUninitialized))
    return;

  SmallVector<IntegerValueRange> sizeOperands =
      getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
  if (llvm::any_of(sizeOperands, isUninitialized))
    return;

  SmallVector<IntegerValueRange> stridesOperands =
      getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
  if (llvm::any_of(stridesOperands, isUninitialized))
    return;

  StridedMetadataRange sourceRange =
      ranges[getSourceMutable().getOperandNumber()];
  if (sourceRange.isUninitialized())
    return;

  ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();

  // Get the dropped dims.
  llvm::SmallBitVector droppedDims = getDroppedDims();

  // Compute the new offset, strides and sizes.
  ConstantIntRanges offset = sourceRange.getOffsets()[0];
  SmallVector<ConstantIntRanges> strides, sizes;

  for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
    bool dropped = droppedDims.test(i);
    // Compute the new offset: every dimension (dropped or not) contributes
    // offset[i] * srcStride[i].
    ConstantIntRanges off =
        intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
    offset = intrange::inferAdd({offset, off});

    // Skip dropped dimensions.
    if (dropped)
      continue;
    // Multiply the strides.
    strides.push_back(
        intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
    // Get the sizes.
    sizes.push_back(sizeOperands[i].getValue());
  }

  // NOTE(review): a call opener (likely `StridedMetadataRange::getRanked(`)
  // matching the extra closing parenthesis below appears to have been lost
  // in extraction — restore from upstream before building.
  setMetadata(getResult(),
              SmallVector<ConstantIntRanges>({std::move(offset)}),
              std::move(sizes), std::move(strides)));
}
3779
3780//===----------------------------------------------------------------------===//
3781// TransposeOp
3782//===----------------------------------------------------------------------===//
3783
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest "transpose" as the SSA name prefix for the result.
  setNameFn(getResult(), "transpose");
}
3788
3789/// Build a strided memref type by applying `permutationMap` to `memRefType`.
3790static MemRefType inferTransposeResultType(MemRefType memRefType,
3791 AffineMap permutationMap) {
3792 auto originalSizes = memRefType.getShape();
3793 auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3794 assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3795
3796 // Compute permuted sizes and strides.
3797 auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3798 auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3799
3800 return MemRefType::Builder(memRefType)
3801 .setShape(sizes)
3802 .setLayout(
3803 StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3804}
3805
3806void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3807 AffineMapAttr permutation,
3808 ArrayRef<NamedAttribute> attrs) {
3809 auto permutationMap = permutation.getValue();
3810 assert(permutationMap);
3811
3812 auto memRefType = llvm::cast<MemRefType>(in.getType());
3813 // Compute result type.
3814 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3815
3816 result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3817 build(b, result, resultType, in, attrs);
3818}
3819
// transpose $in $permutation attr-dict : type($in) `to` type(results)
void TransposeOp::print(OpAsmPrinter &p) {
  p << " " << getIn() << " " << getPermutation();
  // The permutation is printed inline above, so elide it from the attr-dict.
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
  p << " : " << getIn().getType() << " to " << getType();
}
3826
// Parses: transpose $in $permutation attr-dict : type($in) `to` type(results)
// (the inverse of TransposeOp::print above).
ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand in;
  AffineMap permutation;
  MemRefType srcType, dstType;
  // Parse in strict syntactic order; any failure aborts the whole parse.
  if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(in, srcType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types))
    return failure();

  // Store the inline-printed permutation as the op's permutation attribute.
  result.addAttribute(TransposeOp::getPermutationAttrStrName(),
                      AffineMapAttr::get(permutation));
  return success();
}
3843
3844LogicalResult TransposeOp::verify() {
3845 if (!getPermutation().isPermutation())
3846 return emitOpError("expected a permutation map");
3847 if (getPermutation().getNumDims() != getIn().getType().getRank())
3848 return emitOpError("expected a permutation map of same rank as the input");
3849
3850 auto srcType = llvm::cast<MemRefType>(getIn().getType());
3851 auto resultType = llvm::cast<MemRefType>(getType());
3852 auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3853 .canonicalizeStridedLayout();
3854
3855 if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3856 return emitOpError("result type ")
3857 << resultType
3858 << " is not equivalent to the canonical transposed input type "
3859 << canonicalResultType;
3860 return success();
3861}
3862
3863OpFoldResult TransposeOp::fold(FoldAdaptor) {
3864 // First check for identity permutation, we can fold it away if input and
3865 // result types are identical already.
3866 if (getPermutation().isIdentity() && getType() == getIn().getType())
3867 return getIn();
3868 // Fold two consecutive memref.transpose Ops into one by composing their
3869 // permutation maps.
3870 if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3871 AffineMap composedPermutation =
3872 getPermutation().compose(otherTransposeOp.getPermutation());
3873 getInMutable().assign(otherTransposeOp.getIn());
3874 setPermutation(composedPermutation);
3875 return getResult();
3876 }
3877 return {};
3878}
3879
FailureOr<std::optional<SmallVector<Value>>>
TransposeOp::bubbleDownCasts(OpBuilder &builder) {
  // Delegate to the shared pass-through implementation on the input operand.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
}
3884
3885//===----------------------------------------------------------------------===//
3886// ViewOp
3887//===----------------------------------------------------------------------===//
3888
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  // Suggest "view" as the SSA name prefix for the result.
  setNameFn(getResult(), "view");
}
3892
3893LogicalResult ViewOp::verify() {
3894 auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3895 auto viewType = getType();
3896
3897 // The base memref should have identity layout map (or none).
3898 if (!baseType.getLayout().isIdentity())
3899 return emitError("unsupported map for base memref type ") << baseType;
3900
3901 // The result memref should have identity layout map (or none).
3902 if (!viewType.getLayout().isIdentity())
3903 return emitError("unsupported map for result memref type ") << viewType;
3904
3905 // The base memref and the view memref should be in the same memory space.
3906 if (baseType.getMemorySpace() != viewType.getMemorySpace())
3907 return emitError("different memory spaces specified for base memref "
3908 "type ")
3909 << baseType << " and view memref type " << viewType;
3910
3911 // Verify that we have the correct number of sizes for the result type.
3912 if (failed(verifyDynamicDimensionCount(getOperation(), viewType, getSizes())))
3913 return failure();
3914
3915 return success();
3916}
3917
/// For ViewLikeOpInterface: a view aliases its `source` operand.
Value ViewOp::getViewSource() { return getSource(); }
3919
3920OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3921 MemRefType sourceMemrefType = getSource().getType();
3922 MemRefType resultMemrefType = getResult().getType();
3923
3924 if (resultMemrefType == sourceMemrefType &&
3925 resultMemrefType.hasStaticShape() && isZeroInteger(getByteShift()))
3926 return getViewSource();
3927
3928 return {};
3929}
3930
3931SmallVector<OpFoldResult> ViewOp::getMixedSizes() {
3932 SmallVector<OpFoldResult> result;
3933 unsigned ctr = 0;
3934 Builder b(getContext());
3935 for (int64_t dim : getType().getShape()) {
3936 if (ShapedType::isDynamic(dim)) {
3937 result.push_back(getSizes()[ctr++]);
3938 } else {
3939 result.push_back(b.getIndexAttr(dim));
3940 }
3941 }
3942 return result;
3943}
3944
3945namespace {
3946/// Given a memref type and a range of values that defines its dynamic
3947/// dimension sizes, turn all dynamic sizes that have a constant value into
3948/// static dimension sizes.
3949static MemRefType
3950foldDynamicToStaticDimSizes(MemRefType type, ValueRange dynamicSizes,
3951 SmallVectorImpl<Value> &foldedDynamicSizes) {
3952 SmallVector<int64_t> staticShape(type.getShape());
3953 assert(type.getNumDynamicDims() == dynamicSizes.size() &&
3954 "incorrect number of dynamic sizes");
3955
3956 // Compute new static and dynamic sizes.
3957 unsigned ctr = 0;
3958 for (auto [dim, dimSize] : llvm::enumerate(type.getShape())) {
3959 if (ShapedType::isStatic(dimSize))
3960 continue;
3961
3962 Value dynamicSize = dynamicSizes[ctr++];
3963 if (auto cst = getConstantIntValue(dynamicSize)) {
3964 // Dynamic size must be non-negative.
3965 if (cst.value() < 0) {
3966 foldedDynamicSizes.push_back(dynamicSize);
3967 continue;
3968 }
3969 staticShape[dim] = cst.value();
3970 } else {
3971 foldedDynamicSizes.push_back(dynamicSize);
3972 }
3973 }
3974
3975 return MemRefType::Builder(type).setShape(staticShape);
3976}
3977
3978/// Change the result type of a `memref.view` by making originally dynamic
3979/// dimensions static when their sizes come from `constant` ops.
3980/// Example:
3981/// ```
3982/// %c5 = arith.constant 5: index
3983/// %0 = memref.view %src[%offset][%c5] : memref<?xi8> to memref<?x4xf32>
3984/// ```
3985/// to
3986/// ```
3987/// %0 = memref.view %src[%offset][] : memref<?xi8> to memref<5x4xf32>
3988/// ```
3989struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3990 using Base::Base;
3991
3992 LogicalResult matchAndRewrite(ViewOp viewOp,
3993 PatternRewriter &rewriter) const override {
3994 SmallVector<Value> foldedDynamicSizes;
3995 MemRefType resultType = viewOp.getType();
3996 MemRefType foldedMemRefType = foldDynamicToStaticDimSizes(
3997 resultType, viewOp.getSizes(), foldedDynamicSizes);
3998
3999 // Stop here if no dynamic size was promoted to static.
4000 if (foldedMemRefType == resultType)
4001 return failure();
4002
4003 // Create new ViewOp.
4004 auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), foldedMemRefType,
4005 viewOp.getSource(), viewOp.getByteShift(),
4006 foldedDynamicSizes);
4007 // Insert a cast so we have the same type as the old memref type.
4008 rewriter.replaceOpWithNewOp<CastOp>(viewOp, resultType, newViewOp);
4009 return success();
4010 }
4011};
4012
4013/// view(memref.cast(%source)) -> view(%source).
4014struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
4015 using Base::Base;
4016
4017 LogicalResult matchAndRewrite(ViewOp viewOp,
4018 PatternRewriter &rewriter) const override {
4019 auto memrefCastOp = viewOp.getSource().getDefiningOp<CastOp>();
4020 if (!memrefCastOp)
4021 return failure();
4022
4023 rewriter.replaceOpWithNewOp<ViewOp>(
4024 viewOp, viewOp.getType(), memrefCastOp.getSource(),
4025 viewOp.getByteShift(), viewOp.getSizes());
4026 return success();
4027 }
4028};
4029} // namespace
4030
void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  // Registered patterns: promote constant dynamic sizes to static dims and
  // fold away producer memref.casts.
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
4035
FailureOr<std::optional<SmallVector<Value>>>
ViewOp::bubbleDownCasts(OpBuilder &builder) {
  // Delegate to the shared pass-through implementation on the source operand.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
4040
4041//===----------------------------------------------------------------------===//
4042// AtomicRMWOp
4043//===----------------------------------------------------------------------===//
4044
4045LogicalResult AtomicRMWOp::verify() {
4046 switch (getKind()) {
4047 case arith::AtomicRMWKind::addf:
4048 case arith::AtomicRMWKind::maximumf:
4049 case arith::AtomicRMWKind::minimumf:
4050 case arith::AtomicRMWKind::mulf:
4051 if (!llvm::isa<FloatType>(getValue().getType()))
4052 return emitOpError() << "with kind '"
4053 << arith::stringifyAtomicRMWKind(getKind())
4054 << "' expects a floating-point type";
4055 break;
4056 case arith::AtomicRMWKind::addi:
4057 case arith::AtomicRMWKind::maxs:
4058 case arith::AtomicRMWKind::maxu:
4059 case arith::AtomicRMWKind::mins:
4060 case arith::AtomicRMWKind::minu:
4061 case arith::AtomicRMWKind::muli:
4062 case arith::AtomicRMWKind::ori:
4063 case arith::AtomicRMWKind::xori:
4064 case arith::AtomicRMWKind::andi:
4065 if (!llvm::isa<IntegerType>(getValue().getType()))
4066 return emitOpError() << "with kind '"
4067 << arith::stringifyAtomicRMWKind(getKind())
4068 << "' expects an integer type";
4069 break;
4070 default:
4071 break;
4072 }
4073 return success();
4074}
4075
4076OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
4077 /// atomicrmw(memrefcast) -> atomicrmw
4078 if (succeeded(foldMemRefCast(*this, getValue())))
4079 return getResult();
4080 return OpFoldResult();
4081}
4082
FailureOr<std::optional<SmallVector<Value>>>
AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the opening of this call (a `return
  // bubbleDownCasts...Impl(*this, builder, ...,` line matching the orphaned
  // `getResult());` below) appears to have been lost in extraction —
  // restore from upstream before building.
                                                 getResult());
}
4088
/// The memref accessed (read-modified-written) by this atomic operation.
TypedValue<MemRefType> AtomicRMWOp::getAccessedMemref() { return getMemref(); }
4090
std::optional<SmallVector<Value>>
AtomicRMWOp::updateMemrefAndIndices(RewriterBase &rewriter, Value newMemref,
                                    ValueRange newIndices) {
  // Update the op in place — swap in the new memref and index operands —
  // rather than creating a replacement op, hence no replacement values are
  // returned.
  rewriter.modifyOpInPlace(*this, [&]() {
    getMemrefMutable().assign(newMemref);
    getIndicesMutable().assign(newIndices);
  });
  return std::nullopt;
}
4100
4101//===----------------------------------------------------------------------===//
4102// TableGen'd op method definitions
4103//===----------------------------------------------------------------------===//
4104
4105#define GET_OP_CLASSES
4106#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool hasSideEffects(Operation *op)
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition IRAffine.cpp:59
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
auto load
static LogicalResult foldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
Definition MemRefOps.cpp:98
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMaskByStrides(MemRefType originalType, MemRefType reducedType, ArrayRef< int64_t > originalStrides, ArrayRef< int64_t > candidateStrides, llvm::SmallBitVector unusedDims)
Returns the set of source dimensions that are dropped in a rank reduction.
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurrences of it.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMaskByPosition(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Returns the set of source dimensions that are dropped in a rank reduction.
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:255
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
IndexType getIndexType()
Definition Builders.cpp:55
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:409
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:248
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Region.h:98
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StridedMetadataRange getRanked(SmallVectorImpl< ConstantIntRanges > &&offsets, SmallVectorImpl< ConstantIntRanges > &&sizes, SmallVectorImpl< ConstantIntRanges > &&strides)
Returns a ranked strided metadata range.
ArrayRef< ConstantIntRanges > getStrides() const
Get the strides ranges.
bool isUninitialized() const
Returns whether the metadata is uninitialized.
ArrayRef< ConstantIntRanges > getOffsets() const
Get the offsets range.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIndex() const
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition MemRefOps.cpp:62
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition MemRefOps.cpp:70
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:79
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:26
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn
Callback function type for setting the strided metadata of a value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
function_ref< IntegerValueRange(Value)> GetIntRangeFn
Helper callback type to get the integer range of a value.
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.