MLIR 23.0.0git
MemRefOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/IR/AffineMap.h"
15#include "mlir/IR/Builders.h"
17#include "mlir/IR/Matchers.h"
25#include "llvm/ADT/STLExtras.h"
26#include "llvm/ADT/SmallBitVector.h"
27#include "llvm/ADT/SmallVectorExtras.h"
28
29using namespace mlir;
30using namespace mlir::memref;
31
32/// Materialize a single constant operation from a given attribute value with
33/// the desired resultant type.
34Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
35 Attribute value, Type type,
36 Location loc) {
37 return arith::ConstantOp::materialize(builder, value, type, loc);
38}
39
40//===----------------------------------------------------------------------===//
41// Common canonicalization pattern support logic
42//===----------------------------------------------------------------------===//
43
44/// This is a common class used for patterns of the form
45/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
46/// into the root operation directly.
47LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
48 bool folded = false;
49 for (OpOperand &operand : op->getOpOperands()) {
50 auto cast = operand.get().getDefiningOp<CastOp>();
51 if (cast && operand.get() != inner &&
52 !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
53 operand.set(cast.getOperand());
54 folded = true;
55 }
56 }
57 return success(folded);
58}
59
60/// Return an unranked/ranked tensor type for the given unranked/ranked memref
61/// type.
63 if (auto memref = llvm::dyn_cast<MemRefType>(type))
64 return RankedTensorType::get(memref.getShape(), memref.getElementType());
65 if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
66 return UnrankedTensorType::get(memref.getElementType());
67 return NoneType::get(type.getContext());
68}
69
71 int64_t dim) {
72 auto memrefType = llvm::cast<MemRefType>(value.getType());
73 if (memrefType.isDynamicDim(dim))
74 return builder.createOrFold<memref::DimOp>(loc, value, dim);
75
76 return builder.getIndexAttr(memrefType.getDimSize(dim));
77}
78
80 Location loc, Value value) {
81 auto memrefType = llvm::cast<MemRefType>(value.getType());
83 for (int64_t i = 0; i < memrefType.getRank(); ++i)
84 result.push_back(getMixedSize(builder, loc, value, i));
85 return result;
86}
87
88//===----------------------------------------------------------------------===//
89// Utility functions for propagating static information
90//===----------------------------------------------------------------------===//
91
92/// Helper function that sets values[i] to constValues[i] if the latter is a
93/// static value, as indicated by ShapedType::kDynamic.
94///
95/// If constValues[i] is dynamic, tries to extract a constant value from
96/// value[i] to allow for additional folding opportunities. Also convertes all
97/// existing attributes to index attributes. (They may be i64 attributes.)
99 ArrayRef<int64_t> constValues) {
100 assert(constValues.size() == values.size() &&
101 "incorrect number of const values");
102 for (auto [i, cstVal] : llvm::enumerate(constValues)) {
103 Builder builder(values[i].getContext());
104 if (ShapedType::isStatic(cstVal)) {
105 // Constant value is known, use it directly.
106 values[i] = builder.getIndexAttr(cstVal);
107 continue;
108 }
109 if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
110 // Try to extract a constant or convert an existing to index.
111 values[i] = builder.getIndexAttr(*cst);
112 }
113 }
114}
115
116/// Helper function to retrieve a lossless memory-space cast, and the
117/// corresponding new result memref type.
118static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
120 MemorySpaceCastOpInterface castOp =
121 MemorySpaceCastOpInterface::getIfPromotableCast(src);
122
123 // Bail if the cast is not lossless.
124 if (!castOp)
125 return {};
126
127 // Transform the source and target type of `castOp` to have the same metadata
128 // as `resultTy`. Bail if not possible.
129 FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
130 castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
131 if (failed(srcTy))
132 return {};
133
134 FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
135 castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
136 if (failed(tgtTy))
137 return {};
138
139 // Check if this is a valid memory-space cast.
140 if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
141 return {};
142
143 return std::make_tuple(castOp, *tgtTy, *srcTy);
144}
145
146/// Implementation of `bubbleDownCasts` method for memref operations that
147/// return a single memref result.
148template <typename ConcreteOpTy>
149static FailureOr<std::optional<SmallVector<Value>>>
151 OpOperand &src) {
152 auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
153 // Bail if we cannot cast.
154 if (!castOp)
155 return failure();
156
157 // Create the new operands.
158 SmallVector<Value> operands;
159 llvm::append_range(operands, op->getOperands());
160 operands[src.getOperandNumber()] = castOp.getSourcePtr();
161
162 // Create the new op and results.
163 auto newOp = ConcreteOpTy::create(
164 builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
165 llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
166
167 // Insert a memory-space cast to the original memory space of the op.
168 MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
169 builder, tgtTy,
170 cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
171 return std::optional<SmallVector<Value>>(
172 SmallVector<Value>({result.getTargetPtr()}));
173}
174
175//===----------------------------------------------------------------------===//
176// AllocOp / AllocaOp
177//===----------------------------------------------------------------------===//
178
// Suggest "alloc" as the SSA name prefix for this op's result in the printer.
void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}
183
// Suggest "alloca" as the SSA name prefix for this op's result in the printer.
void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}
188
189template <typename AllocLikeOp>
190static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
191 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
192 "applies to only alloc or alloca");
193 auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
194 if (!memRefType)
195 return op.emitOpError("result must be a memref");
196
197 if (failed(verifyDynamicDimensionCount(op, memRefType, op.getDynamicSizes())))
198 return failure();
199
200 unsigned numSymbols = 0;
201 if (!memRefType.getLayout().isIdentity())
202 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
203 if (op.getSymbolOperands().size() != numSymbols)
204 return op.emitOpError("symbol operand count does not equal memref symbol "
205 "count: expected ")
206 << numSymbols << ", got " << op.getSymbolOperands().size();
207
208 return success();
209}
210
// AllocOp has no constraints beyond the shared alloc/alloca checks.
LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
212
213LogicalResult AllocaOp::verify() {
214 // An alloca op needs to have an ancestor with an allocation scope trait.
215 if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
216 return emitOpError(
217 "requires an ancestor op with AutomaticAllocationScope trait");
218
219 return verifyAllocLikeOp(*this);
220}
221
namespace {
/// Fold constant dimensions into an alloc like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimensions operands are constants. If so, we can
    // substitute and drop them.
    // Negative "constant" sizes are invalid; they are left for the verifier.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    // Index into the op's dynamic-size operand list; advanced once per
    // dynamic dimension of the original type.
    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already static dimension, keep it.
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        // Zero-extension is safe: the value was checked non-negative above.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old memref.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
                                        dynamicSizes, alloc.getSymbolOperands(),
                                        alloc.getAlignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    // Bail if any user is something other than a store *into* the buffer or a
    // dealloc. A store whose stored *value* is the allocation keeps it alive.
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    // Erase all (store/dealloc) users first, then the allocation itself.
    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // namespace
307
// Canonicalizations: fold constant sizes into the type; drop dead allocs.
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}
312
// Canonicalizations: fold constant sizes into the type; drop dead allocas.
void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}
318
319//===----------------------------------------------------------------------===//
320// ReallocOp
321//===----------------------------------------------------------------------===//
322
323LogicalResult ReallocOp::verify() {
324 auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
325 MemRefType resultType = getType();
326
327 // The source memref should have identity layout (or none).
328 if (!sourceType.getLayout().isIdentity())
329 return emitError("unsupported layout for source memref type ")
330 << sourceType;
331
332 // The result memref should have identity layout (or none).
333 if (!resultType.getLayout().isIdentity())
334 return emitError("unsupported layout for result memref type ")
335 << resultType;
336
337 // The source memref and the result memref should be in the same memory space.
338 if (sourceType.getMemorySpace() != resultType.getMemorySpace())
339 return emitError("different memory spaces specified for source memref "
340 "type ")
341 << sourceType << " and result memref type " << resultType;
342
343 // The source memref and the result memref should have the same element type.
344 if (failed(verifyElementTypesMatch(*this, sourceType, resultType, "source",
345 "result")))
346 return failure();
347
348 // Verify that we have the dynamic dimension operand when it is needed.
349 if (resultType.getNumDynamicDims() && !getDynamicResultSize())
350 return emitError("missing dimension operand for result type ")
351 << resultType;
352 if (!resultType.getNumDynamicDims() && getDynamicResultSize())
353 return emitError("unnecessary dimension operand for result type ")
354 << resultType;
355
356 return success();
357}
358
// A realloc whose result is unused (or only stored-into/deallocated) can be
// removed by the shared dead-alloc pattern.
void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
}
363
364//===----------------------------------------------------------------------===//
365// AllocaScopeOp
366//===----------------------------------------------------------------------===//
367
// Print: optional " -> (result types)" list, then the body region, then the
// attribute dictionary. Terminators are printed only when there are results,
// since the terminator then carries the yielded values.
void AllocaScopeOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getBodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs());
}
382
383ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
384 // Create a region for the body.
385 result.regions.reserve(1);
386 Region *bodyRegion = result.addRegion();
387
388 // Parse optional results type list.
389 if (parser.parseOptionalArrowTypeList(result.types))
390 return failure();
391
392 // Parse the body region.
393 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
394 return failure();
395 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
396 result.location);
397
398 // Parse the optional attribute list.
399 if (parser.parseOptionalAttrDict(result.attributes))
400 return failure();
401
402 return success();
403}
404
405void AllocaScopeOp::getSuccessorRegions(
407 if (!point.isParent()) {
408 regions.push_back(RegionSuccessor::parent());
409 return;
410 }
411
412 regions.push_back(RegionSuccessor(&getBodyRegion()));
413}
414
415ValueRange AllocaScopeOp::getSuccessorInputs(RegionSuccessor successor) {
416 return successor.isParent() ? ValueRange(getResults()) : ValueRange();
417}
418
419/// Given an operation, return whether this op is guaranteed to
420/// allocate an AutomaticAllocationScopeResource
422 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
423 if (!interface)
424 return false;
425 for (auto res : op->getResults()) {
426 if (auto effect =
427 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
428 if (isa<SideEffects::AutomaticAllocationScopeResource>(
429 effect->getResource()))
430 return true;
431 }
432 }
433 return false;
434}
435
436/// Given an operation, return whether this op itself could
437/// allocate an AutomaticAllocationScopeResource. Note that
438/// this will not check whether an operation contained within
439/// the op can allocate.
441 // This op itself doesn't create a stack allocation,
442 // the inner allocation should be handled separately.
444 return false;
445 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
446 if (!interface)
447 return true;
448 for (auto res : op->getResults()) {
449 if (auto effect =
450 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
451 if (isa<SideEffects::AutomaticAllocationScopeResource>(
452 effect->getResource()))
453 return true;
454 }
455 }
456 return false;
457}
458
459/// Return whether this op is the last non terminating op
460/// in a region. That is to say, it is in a one-block region
461/// and is only followed by a terminator. This prevents
462/// extending the lifetime of allocations.
464 return op->getBlock()->mightHaveTerminator() &&
465 op->getNextNode() == op->getBlock()->getTerminator() &&
467}
468
469/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
470/// or it contains no allocation.
471struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
472 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
473
474 LogicalResult matchAndRewrite(AllocaScopeOp op,
475 PatternRewriter &rewriter) const override {
476 bool hasPotentialAlloca =
477 op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
478 if (alloc == op)
479 return WalkResult::advance();
481 return WalkResult::interrupt();
482 if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
483 return WalkResult::skip();
484 return WalkResult::advance();
485 }).wasInterrupted();
486
487 // If this contains no potential allocation, it is always legal to
488 // inline. Otherwise, consider two conditions:
489 if (hasPotentialAlloca) {
490 // If the parent isn't an allocation scope, or we are not the last
491 // non-terminator op in the parent, we will extend the lifetime.
492 if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
493 return failure();
495 return failure();
496 }
497
498 Block *block = &op.getRegion().front();
499 Operation *terminator = block->getTerminator();
500 ValueRange results = terminator->getOperands();
501 rewriter.inlineBlockBefore(block, op);
502 rewriter.replaceOp(op, results);
503 rewriter.eraseOp(terminator);
504 return success();
505 }
506};
507
508/// Move allocations into an allocation scope, if it is legal to
509/// move them (e.g. their operands are available at the location
510/// the op would be moved to).
511struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
512 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
513
514 LogicalResult matchAndRewrite(AllocaScopeOp op,
515 PatternRewriter &rewriter) const override {
516
517 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
518 return failure();
519
520 Operation *lastParentWithoutScope = op->getParentOp();
521
522 if (!lastParentWithoutScope ||
523 lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
524 return failure();
525
526 // Only apply to if this is this last non-terminator
527 // op in the block (lest lifetime be extended) of a one
528 // block region
529 if (!lastNonTerminatorInRegion(op) ||
530 !lastNonTerminatorInRegion(lastParentWithoutScope))
531 return failure();
532
533 while (!lastParentWithoutScope->getParentOp()
535 lastParentWithoutScope = lastParentWithoutScope->getParentOp();
536 if (!lastParentWithoutScope ||
537 !lastNonTerminatorInRegion(lastParentWithoutScope))
538 return failure();
539 }
540 assert(lastParentWithoutScope->getParentOp()
542
543 Region *containingRegion = nullptr;
544 for (auto &r : lastParentWithoutScope->getRegions()) {
545 if (r.isAncestor(op->getParentRegion())) {
546 assert(containingRegion == nullptr &&
547 "only one region can contain the op");
548 containingRegion = &r;
549 }
550 }
551 assert(containingRegion && "op must be contained in a region");
552
554 op->walk([&](Operation *alloc) {
556 return WalkResult::skip();
557
558 // If any operand is not defined before the location of
559 // lastParentWithoutScope (i.e. where we would hoist to), skip.
560 if (llvm::any_of(alloc->getOperands(), [&](Value v) {
561 return containingRegion->isAncestor(v.getParentRegion());
562 }))
563 return WalkResult::skip();
564 toHoist.push_back(alloc);
565 return WalkResult::advance();
566 });
567
568 if (toHoist.empty())
569 return failure();
570 rewriter.setInsertionPoint(lastParentWithoutScope);
571 for (auto *op : toHoist) {
572 auto *cloned = rewriter.clone(*op);
573 rewriter.replaceOp(op, cloned->getResults());
574 }
575 return success();
576 }
577};
578
// Canonicalizations: inline redundant scopes; hoist allocas to outer scopes.
void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
}
583
584//===----------------------------------------------------------------------===//
585// AssumeAlignmentOp
586//===----------------------------------------------------------------------===//
587
588LogicalResult AssumeAlignmentOp::verify() {
589 if (!llvm::isPowerOf2_32(getAlignment()))
590 return emitOpError("alignment must be power of 2");
591 return success();
592}
593
// Suggest "assume_align" as the SSA name prefix for this op's result.
void AssumeAlignmentOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "assume_align");
}
598
599OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
600 auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
601 if (!source)
602 return {};
603 if (source.getAlignment() != getAlignment())
604 return {};
605 return getMemref();
606}
607
// Pass-through op: delegate cast bubbling to the shared implementation on the
// memref operand.
FailureOr<std::optional<SmallVector<Value>>>
AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
}
612
// The result aliases the source memref, so the size of any result dimension
// is simply the corresponding size of the source.
FailureOr<OpFoldResult> AssumeAlignmentOp::reifyDimOfResult(OpBuilder &builder,
                                                            int resultIndex,
                                                            int dim) {
  assert(resultIndex == 0 && "AssumeAlignmentOp has a single result");
  return getMixedSize(builder, getLoc(), getMemref(), dim);
}
619
620//===----------------------------------------------------------------------===//
621// DistinctObjectsOp
622//===----------------------------------------------------------------------===//
623
624LogicalResult DistinctObjectsOp::verify() {
625 if (getOperandTypes() != getResultTypes())
626 return emitOpError("operand types and result types must match");
627
628 if (getOperandTypes().empty())
629 return emitOpError("expected at least one operand");
630
631 return success();
632}
633
634LogicalResult DistinctObjectsOp::inferReturnTypes(
635 MLIRContext * /*context*/, std::optional<Location> /*location*/,
636 ValueRange operands, DictionaryAttr /*attributes*/,
637 OpaqueProperties /*properties*/, RegionRange /*regions*/,
638 SmallVectorImpl<Type> &inferredReturnTypes) {
639 llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
640 return success();
641}
642
643//===----------------------------------------------------------------------===//
644// CastOp
645//===----------------------------------------------------------------------===//
646
// Suggest "cast" as the SSA name prefix for this op's result in the printer.
void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}
650
651/// Determines whether MemRef_CastOp casts to a more dynamic version of the
652/// source memref. This is useful to fold a memref.cast into a consuming op
653/// and implement canonicalization patterns for ops in different dialects that
654/// may consume the results of memref.cast operations. Such foldable memref.cast
655/// operations are typically inserted as `view` and `subview` ops are
656/// canonicalized, to preserve the type compatibility of their uses.
657///
658/// Returns true when all conditions are met:
659/// 1. source and result are ranked memrefs with strided semantics and same
660/// element type and rank.
661/// 2. each of the source's size, offset or stride has more static information
662/// than the corresponding result's size, offset or stride.
663///
664/// Example 1:
665/// ```mlir
666/// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
667/// %2 = consumer %1 ... : memref<?x?xf32> ...
668/// ```
669///
670/// may fold into:
671///
672/// ```mlir
673/// %2 = consumer %0 ... : memref<8x16xf32> ...
674/// ```
675///
676/// Example 2:
677/// ```
678/// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
679/// to memref<?x?xf32>
680/// consumer %1 : memref<?x?xf32> ...
681/// ```
682///
683/// may fold into:
684///
685/// ```
686/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
687/// ```
688bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
689 MemRefType sourceType =
690 llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
691 MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
692
693 // Requires ranked MemRefType.
694 if (!sourceType || !resultType)
695 return false;
696
697 // Requires same elemental type.
698 if (sourceType.getElementType() != resultType.getElementType())
699 return false;
700
701 // Requires same rank.
702 if (sourceType.getRank() != resultType.getRank())
703 return false;
704
705 // Only fold casts between strided memref forms.
706 int64_t sourceOffset, resultOffset;
707 SmallVector<int64_t, 4> sourceStrides, resultStrides;
708 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
709 failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
710 return false;
711
712 // If cast is towards more static sizes along any dimension, don't fold.
713 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
714 auto ss = std::get<0>(it), st = std::get<1>(it);
715 if (ss != st)
716 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
717 return false;
718 }
719
720 // If cast is towards more static offset along any dimension, don't fold.
721 if (sourceOffset != resultOffset)
722 if (ShapedType::isDynamic(sourceOffset) &&
723 ShapedType::isStatic(resultOffset))
724 return false;
725
726 // If cast is towards more static strides along any dimension, don't fold.
727 for (auto it : llvm::zip(sourceStrides, resultStrides)) {
728 auto ss = std::get<0>(it), st = std::get<1>(it);
729 if (ss != st)
730 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
731 return false;
732 }
733
734 return true;
735}
736
/// Return true when a memref.cast between the single input and output type is
/// legal: ranked<->ranked casts need matching element type, memory space,
/// rank, and compatible (static-or-equal / dynamic) sizes, strides and
/// offset; ranked<->unranked casts need matching element type and memory
/// space. Unranked-to-unranked is not supported.
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  // Ranked -> ranked case.
  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    // Differing layouts are only compatible when both sides are strided with
    // pairwise-compatible strides/offset.
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      // Note that for dimensions of size 1, the stride can differ.
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &[index, aStride] : enumerate(aStrides)) {
        if (aT.getDimSize(index) == 1 || bT.getDimSize(index) == 1)
          continue;
        if (!checkCompatible(aStride, bStrides[index]))
          return false;
      }
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
          aDim != bDim)
        return false;
    }
    return true;
  } else {
    // Mixed ranked/unranked case: each side must be some memref type.
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
  }

  // Unreachable: both branches above return.
  return false;
}
810
811OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
812 return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
813}
814
// Pass-through op: delegate cast bubbling to the shared implementation on the
// source operand.
FailureOr<std::optional<SmallVector<Value>>>
CastOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
819
820//===----------------------------------------------------------------------===//
821// CopyOp
822//===----------------------------------------------------------------------===//
823
824namespace {
825
826/// Fold memref.copy(%x, %x).
827struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
828 using OpRewritePattern<CopyOp>::OpRewritePattern;
829
830 LogicalResult matchAndRewrite(CopyOp copyOp,
831 PatternRewriter &rewriter) const override {
832 if (copyOp.getSource() != copyOp.getTarget())
833 return failure();
834
835 rewriter.eraseOp(copyOp);
836 return success();
837 }
838};
839
840struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
841 using OpRewritePattern<CopyOp>::OpRewritePattern;
842
843 static bool isEmptyMemRef(BaseMemRefType type) {
844 return type.hasRank() && llvm::is_contained(type.getShape(), 0);
845 }
846
847 LogicalResult matchAndRewrite(CopyOp copyOp,
848 PatternRewriter &rewriter) const override {
849 if (isEmptyMemRef(copyOp.getSource().getType()) ||
850 isEmptyMemRef(copyOp.getTarget().getType())) {
851 rewriter.eraseOp(copyOp);
852 return success();
853 }
854
855 return failure();
856 }
857};
858} // namespace
859
// Canonicalizations: drop copies of empty buffers and self-copies.
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
}
864
865/// If the source/target of a CopyOp is a CastOp that does not modify the shape
866/// and element type, the cast can be skipped. Such CastOps only cast the layout
867/// of the type.
868static LogicalResult foldCopyOfCast(CopyOp op) {
869 for (OpOperand &operand : op->getOpOperands()) {
870 auto castOp = operand.get().getDefiningOp<memref::CastOp>();
871 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
872 operand.set(castOp.getOperand());
873 return success();
874 }
875 }
876 return failure();
877}
878
// In-place fold: rewrites cast-produced operands; yields no replacement
// values.
LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {

  /// copy(memrefcast) -> copy
  return foldCopyOfCast(*this);
}
885
886//===----------------------------------------------------------------------===//
887// DeallocOp
888//===----------------------------------------------------------------------===//
889
// In-place fold: strips a memref.cast feeding the dealloc'd operand.
LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}
895
896//===----------------------------------------------------------------------===//
897// DimOp
898//===----------------------------------------------------------------------===//
899
// Suggest "dim" as the SSA name prefix for this op's result in the printer.
void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}
903
// Convenience builder: materialize the static `index` as an arith constant,
// then delegate to the value-index builder overload.
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
  build(builder, result, source, indexValue);
}
910
911std::optional<int64_t> DimOp::getConstantIndex() {
913}
914
915Speculation::Speculatability DimOp::getSpeculatability() {
916 auto constantIndex = getConstantIndex();
917 if (!constantIndex)
919
920 auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
921 if (!rankedSourceType)
923
924 if (rankedSourceType.getRank() <= constantIndex)
926
928}
929
// Integer-range analysis hook: the result range of dim is derived from the
// shaped operand (argRanges[0] is the source, argRanges[1] the index).
void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}
935
936/// Return a map with key being elements in `vals` and data being number of
937/// occurences of it. Use std::map, since the `vals` here are strides and the
938/// dynamic stride value is the same as the tombstone value for
939/// `DenseMap<int64_t>`.
940static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
941 std::map<int64_t, unsigned> numOccurences;
942 for (auto val : vals)
943 numOccurences[val]++;
944 return numOccurences;
945}
946
947/// Returns the set of source dimensions that are dropped in a rank reduction.
948/// For each result dimension in order, matches the leftmost unmatched source
949/// dimension with the same size. Source dimensions not matched are dropped.
950///
951/// Example: memref<1x8x1x3> to memref<1x8x3>. Source sizes [1, 8, 1, 3], result
952/// [1, 8, 3]. Match result[0]=1 -> source dim 0, result[1]=8 -> source dim 1,
953/// result[2]=3 -> source dim 3. Source dim 2 is unmatched and dropped.
954static FailureOr<llvm::SmallBitVector>
956 MemRefType reducedType,
958 int64_t rankReduction = originalType.getRank() - reducedType.getRank();
959 if (rankReduction <= 0)
960 return llvm::SmallBitVector(originalType.getRank());
961
962 // Build source sizes from subview sizes (one per source dim).
963 SmallVector<int64_t> sourceSizes(originalType.getRank());
964 for (const auto &it : llvm::enumerate(sizes)) {
965 if (std::optional<int64_t> cst = getConstantIntValue(it.value()))
966 sourceSizes[it.index()] = *cst;
967 else
968 sourceSizes[it.index()] = ShapedType::kDynamic;
969 }
970
971 ArrayRef<int64_t> resultSizes = reducedType.getShape();
972 llvm::SmallBitVector usedSourceDims(originalType.getRank());
973 int64_t startJ = 0;
974 for (int64_t resultSize : resultSizes) {
975 bool matched = false;
976 for (int64_t j = startJ; j < originalType.getRank(); ++j) {
977 if (sourceSizes[j] == resultSize) {
978 usedSourceDims.set(j);
979 matched = true;
980 startJ = j + 1;
981 break;
982 }
983 }
984 if (!matched)
985 return failure();
986 }
987
988 llvm::SmallBitVector unusedDims(originalType.getRank());
989 for (int64_t i = 0; i < originalType.getRank(); ++i)
990 if (!usedSourceDims.test(i))
991 unusedDims.set(i);
992 return unusedDims;
993}
994
995/// Returns the set of source dimensions that are dropped in a rank reduction.
996/// A dimension is dropped if its stride is dropped; uses stride occurrence
997/// counting to disambiguate when multiple unit dims exist.
998///
999/// Example: memref<1x1x?xf32, strided<[?, 4, 1]>> to memref<1x4xf32,
1000/// strided<[4, 1]>>. Source strides [?, 4, 1], candidate [4, 1]. Dim 0 (stride
1001/// ?) can be dropped; dim 1 (stride 4) must be kept. Source dim 0 is dropped.
static FailureOr<llvm::SmallBitVector> computeMemRefRankReductionMaskByStrides(
    MemRefType originalType, MemRefType reducedType,
    ArrayRef<int64_t> originalStrides, ArrayRef<int64_t> candidateStrides,
    llvm::SmallBitVector unusedDims) {
  // Track the number of occurrences of each stride in the original type and
  // the candidate type. For each unused dim, that stride should not be
  // present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to figure out exactly
  // which dim corresponds to which stride; we just need to verify that the
  // number of repetitions of a stride in the original + number of unused dims
  // with that stride == number of repetitions of the stride in the candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // The original has more copies of this stride than the candidate needs:
      // this dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this dim is not dropped. Keep the dim as used.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen: the reduced-rank type cannot contain a
      // stride that was not in the original one.
      return failure();
    }
  }
  // The number of dropped dims must account exactly for the rank difference.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() !=
      originalType.getRank())
    return failure();
  return unusedDims;
}
1045
1046/// Given the `originalType` and a `candidateReducedType` whose shape is assumed
1047/// to be a subset of `originalType` with some `1` entries erased, return the
1048/// set of indices that specifies which of the entries of `originalShape` are
1049/// dropped to obtain `reducedShape`.
1050/// This accounts for cases where there are multiple unit-dims, but only a
1051/// subset of those are dropped. For MemRefTypes these can be disambiguated
1052/// using the strides. If a dimension is dropped the stride must be dropped too.
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  // Only static unit dims are candidates for being dropped.
  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the number
  // of ranks reduced: no disambiguation needed.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  // More unit dims than dropped dims: need strides to tell which unit dims
  // were actually dropped. Both types must have a strided layout.
  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
      failed(
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
    return failure();

  // Try stride-based first when we have meaningful static stride info
  // (preserves static strides). Fall back to position-based otherwise.
  auto hasNonTrivialStaticStride = [](ArrayRef<int64_t> strides) {
    // The innermost stride 1 is trivial for row-major and does not help
    // disambiguate.
    if (strides.size() <= 1)
      return false;
    return llvm::any_of(strides.drop_back(),
                        [](int64_t s) { return !ShapedType::isDynamic(s); });
  };
  if (hasNonTrivialStaticStride(originalStrides) ||
      hasNonTrivialStaticStride(candidateStrides)) {
    FailureOr<llvm::SmallBitVector> strideBased =
        computeMemRefRankReductionMaskByStrides(originalType, reducedType,
                                                originalStrides,
                                                candidateStrides, unusedDims);
    if (succeeded(strideBased))
      return *strideBased;
  }
  return computeMemRefRankReductionMaskByPosition(originalType, reducedType,
                                                  sizes);
}
1101
1102llvm::SmallBitVector SubViewOp::getDroppedDims() {
1103 MemRefType sourceType = getSourceType();
1104 MemRefType resultType = getType();
1105 FailureOr<llvm::SmallBitVector> unusedDims =
1106 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1107 assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1108 return *unusedDims;
1109}
1110
OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = getSource().getDefiningOp();

  // Alloc-like ops store only the dynamic sizes; map the dim index to its
  // position among the dynamic dims.
  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    // The result dim is dynamic (the static case was handled above). Dropped
    // dims always have static size 1, so dynamic source sizes are never
    // dropped and map in order to the dynamic result dims. Find the k-th
    // dynamic source size, where k is the dynamic dim index of the result dim.
    unsigned dynamicResultDimIdx = memrefType.getDynamicDimIndex(unsignedIndex);
    unsigned dynamicIdx = 0;
    for (OpFoldResult size : subview.getMixedSizes()) {
      if (llvm::isa<Attribute>(size))
        continue;
      if (dynamicIdx == dynamicResultDimIdx)
        return size;
      dynamicIdx++;
    }
    return {};
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
1175
namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    // dim of a memref reshape can be folded if dim.getIndex() dominates the
    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
    // cheaply check that either of the following conditions hold:
    // 1. dim.getIndex() is defined in the same block as reshape but before
    // reshape.
    // 2. dim.getIndex() is defined in a parent block of
    // reshape.

    // Check condition 1
    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
        }
      } // else dim.getIndex is a block argument to reshape->getBlock and
        // dominates reshape
    } // Check condition 2
    else if (dim->getBlock() != reshape->getBlock() &&
             !dim.getIndex().getParentRegion()->isProperAncestor(
                 reshape->getParentRegion())) {
      // If dim and reshape are in the same block but dim.getIndex() isn't, we
      // already know dim.getIndex() dominates reshape without calling
      // `isProperAncestor`
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");
    }

    // Place the load directly after the reshape to ensure that the shape memref
    // was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    // The loaded shape element may not be of index type; insert a cast when
    // its type differs from the dim result type.
    Value load =
        LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

} // namespace
1233
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  // dim(reshape(_, shape), i) -> load(shape, i), defined above.
  results.add<DimOfMemRefReshape>(context);
}
1238
1239// ---------------------------------------------------------------------------
1240// DmaStartOp
1241// ---------------------------------------------------------------------------
1242
// Assemble the DMA start operand list in the canonical order expected by the
// accessors/verifier: src memref + indices, dst memref + indices, number of
// elements, tag memref + indices, then the optional (stride,
// elementsPerStride) pair.
void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  // Stride operands are optional, but must be present or absent as a pair.
  if (stride)
    result.addOperands({stride, elementsPerStride});
}
1257
// Print in the custom form parsed by DmaStartOp::parse:
//   %src[...], %dst[...], %numElts, %tag[...] (, %stride, %eltsPerStride)?
//   {attrs} : srcType, dstType, tagType
void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}
1269
1270// Parse DmaStartOp.
1271// Ex:
1272// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1273// %tag[%index], %stride, %num_elt_per_stride :
1274// : memref<3076 x f32, 0>,
1275// memref<1024 x f32, 2>,
1276// memref<1 x i32>
1277//
// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  OpAsmParser::UnresolvedOperand tagMemrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size in KiB.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  // Stride operands must come as a (stride, elementsPerStride) pair.
  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  // Exactly three types: source memref, destination memref, tag memref.
  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  // Resolution order must mirror the operand order laid down by build().
  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}
1337
LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
  // the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}
1403
LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  // In-place fold: look through layout-changing memref.cast producers on any
  // of the DMA's memref operands.
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}
1409
1410// ---------------------------------------------------------------------------
1411// DmaWaitOp
1412// ---------------------------------------------------------------------------
1413
LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  // In-place fold: look through layout-changing memref.cast producers on the
  // tag memref operand.
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}
1419
1420LogicalResult DmaWaitOp::verify() {
1421 // Check that the number of tag indices matches the tagMemRef rank.
1422 unsigned numTagIndices = getTagIndices().size();
1423 unsigned tagMemRefRank = getTagMemRefRank();
1424 if (numTagIndices != tagMemRefRank)
1425 return emitOpError() << "expected tagIndices to have the same number of "
1426 "elements as the tagMemRef rank, expected "
1427 << tagMemRefRank << ", but got " << numTagIndices;
1428 return success();
1429}
1430
1431//===----------------------------------------------------------------------===//
1432// ExtractAlignedPointerAsIndexOp
1433//===----------------------------------------------------------------------===//
1434
// Give the SSA result a readable "%intptr" prefix in printed IR.
void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}
1439
1440//===----------------------------------------------------------------------===//
1441// ExtractStridedMetadataOp
1442//===----------------------------------------------------------------------===//
1443
1444/// The number and type of the results are inferred from the
1445/// shape of the source.
1446LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1447 MLIRContext *context, std::optional<Location> location,
1448 ExtractStridedMetadataOp::Adaptor adaptor,
1449 SmallVectorImpl<Type> &inferredReturnTypes) {
1450 auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1451 if (!sourceType)
1452 return failure();
1453
1454 unsigned sourceRank = sourceType.getRank();
1455 IndexType indexType = IndexType::get(context);
1456 auto memrefType =
1457 MemRefType::get({}, sourceType.getElementType(),
1458 MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1459 // Base.
1460 inferredReturnTypes.push_back(memrefType);
1461 // Offset.
1462 inferredReturnTypes.push_back(indexType);
1463 // Sizes and strides.
1464 for (unsigned i = 0; i < sourceRank * 2; ++i)
1465 inferredReturnTypes.push_back(indexType);
1466 return success();
1467}
1468
void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  // For multi-result to work properly with pretty names and packed syntax `x:3`
  // we can only give a pretty name to the first value in the pack.
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}
1480
1481/// Helper function to perform the replacement of all constant uses of `values`
1482/// by a materialized constant extracted from `maybeConstants`.
1483/// `values` and `maybeConstants` are expected to have the same size.
1484template <typename Container>
1485static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1486 Container values,
1487 ArrayRef<OpFoldResult> maybeConstants) {
1488 assert(values.size() == maybeConstants.size() &&
1489 " expected values and maybeConstants of the same size");
1490 bool atLeastOneReplacement = false;
1491 for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1492 // Don't materialize a constant if there are no uses: this would indice
1493 // infinite loops in the driver.
1494 if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1495 continue;
1496 assert(isa<Attribute>(maybeConstant) &&
1497 "The constified value should be either unchanged (i.e., == result) "
1498 "or a constant");
1500 rewriter, loc,
1501 llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1502 for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1503 // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1504 // yet.
1505 op->replaceUsesOfWith(result, constantVal);
1506 atLeastOneReplacement = true;
1507 }
1508 }
1509 return atLeastOneReplacement;
1510}
1511
LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  // Replace uses of each result with a materialized constant wherever the
  // constified offset/sizes/strides are statically known.
  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
  if (auto prev = getSource().getDefiningOp<CastOp>())
    if (isa<MemRefType>(prev.getSource().getType())) {
      getSourceMutable().assign(prev.getSource());
      atLeastOneReplacement = true;
    }

  // Report success only if something changed, so the folding driver
  // terminates.
  return success(atLeastOneReplacement);
}
1534
// Return the size results as OpFoldResults, replacing each with a static
// attribute where the source type's shape is statically known.
SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
  constifyIndexValues(values, getSource().getType().getShape());
  return values;
}
1540
// Return the stride results as OpFoldResults, replacing each with a static
// attribute where the source type's strides are statically known.
SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  // The source of a (verified) op always has a strided layout.
  (void)status;
  assert(succeeded(status) && "could not get strides from type");
  constifyIndexValues(values, staticValues);
  return values;
}
1553
// Return the offset result as an OpFoldResult, replaced with a static
// attribute when the source type's offset is statically known.
OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
  // Wrap in 1-element vectors so we can reuse constifyIndexValues.
  SmallVector<OpFoldResult> values(1, offsetOfr);
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  constifyIndexValues(values, staticValues);
  return values[0];
}
1567
1568//===----------------------------------------------------------------------===//
1569// GenericAtomicRMWOp
1570//===----------------------------------------------------------------------===//
1571
// Build a generic_atomic_rmw on `memref` at indices `ivs`, creating the body
// region with a single block argument of the memref's element type (the
// current value at the addressed location).
void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  OpBuilder::InsertionGuard g(builder);
  result.addOperands(memref);
  result.addOperands(ivs);

  // If `memref` is not of memref type no result/region is added; the
  // verifier rejects such ops later.
  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    builder.createBlock(bodyRegion);
    bodyRegion->addArgument(elementType, memref.getLoc());
  }
}
1587
LogicalResult GenericAtomicRMWOp::verify() {
  // The body takes exactly one argument: the current value at the location.
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  // The body must be side-effect free: it may be executed multiple times as
  // part of a compare-and-swap style retry loop.
  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (isMemoryEffectFree(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError(
                "body of 'memref.generic_atomic_rmw' should contain "
                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}
1608
1609ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1610 OperationState &result) {
1611 OpAsmParser::UnresolvedOperand memref;
1612 Type memrefType;
1613 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1614
1615 Type indexType = parser.getBuilder().getIndexType();
1616 if (parser.parseOperand(memref) ||
1618 parser.parseColonType(memrefType) ||
1619 parser.resolveOperand(memref, memrefType, result.operands) ||
1620 parser.resolveOperands(ivs, indexType, result.operands))
1621 return failure();
1622
1623 Region *body = result.addRegion();
1624 if (parser.parseRegion(*body, {}) ||
1625 parser.parseOptionalAttrDict(result.attributes))
1626 return failure();
1627 result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1628 return success();
1629}
1630
// Print in the custom form consumed by GenericAtomicRMWOp::parse.
void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
  p << ' ' << getMemref() << "[" << getIndices()
    << "] : " << getMemref().getType() << ' ';
  p.printRegion(getRegion());
  p.printOptionalAttrDict((*this)->getAttrs());
}
1637
1638//===----------------------------------------------------------------------===//
1639// AtomicYieldOp
1640//===----------------------------------------------------------------------===//
1641
1642LogicalResult AtomicYieldOp::verify() {
1643 Type parentType = (*this)->getParentOp()->getResultTypes().front();
1644 Type resultType = getResult().getType();
1645 if (parentType != resultType)
1646 return emitOpError() << "types mismatch between yield op: " << resultType
1647 << " and its parent: " << parentType;
1648 return success();
1649}
1650
1651//===----------------------------------------------------------------------===//
1652// GlobalOp
1653//===----------------------------------------------------------------------===//
1654
1656 TypeAttr type,
1657 Attribute initialValue) {
1658 p << type;
1659 if (!op.isExternal()) {
1660 p << " = ";
1661 if (op.isUninitialized())
1662 p << "uninitialized";
1663 else
1664 p.printAttributeWithoutType(initialValue);
1665 }
1666}
1667
1668static ParseResult
1670 Attribute &initialValue) {
1671 Type type;
1672 if (parser.parseType(type))
1673 return failure();
1674
1675 auto memrefType = llvm::dyn_cast<MemRefType>(type);
1676 if (!memrefType || !memrefType.hasStaticShape())
1677 return parser.emitError(parser.getNameLoc())
1678 << "type should be static shaped memref, but got " << type;
1679 typeAttr = TypeAttr::get(type);
1680
1681 if (parser.parseOptionalEqual())
1682 return success();
1683
1684 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1685 initialValue = UnitAttr::get(parser.getContext());
1686 return success();
1687 }
1688
1689 Type tensorType = getTensorTypeFromMemRefType(memrefType);
1690 if (parser.parseAttribute(initialValue, tensorType))
1691 return failure();
1692 if (!llvm::isa<ElementsAttr>(initialValue))
1693 return parser.emitError(parser.getNameLoc())
1694 << "initial value should be a unit or elements attribute";
1695 return success();
1696}
1697
LogicalResult GlobalOp::verify() {
  // Globals must have a statically shaped memref type so their storage size
  // is known at compile time.
  auto memrefType = llvm::dyn_cast<MemRefType>(getType());
  if (!memrefType || !memrefType.hasStaticShape())
    return emitOpError("type should be static shaped memref, but got ")
           << getType();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (getInitialValue().has_value()) {
    Attribute initValue = getInitialValue().value();
    if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
      return emitOpError("initial value should be a unit or elements "
                         "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
      // Check the element types match.
      auto initElementType =
          cast<TensorType>(elementsAttr.getType()).getElementType();
      auto memrefElementType = memrefType.getElementType();

      if (initElementType != memrefElementType)
        return emitOpError("initial value element expected to be of type ")
               << memrefElementType << ", but was of type " << initElementType;

      // Check the shapes match, given that memref globals can only produce
      // statically shaped memrefs and elements literal type must have a static
      // shape we can assume both types are shaped.
      auto initShape = elementsAttr.getShapedType().getShape();
      auto memrefShape = memrefType.getShape();
      if (initShape != memrefShape)
        return emitOpError("initial value shape expected to be ")
               << memrefShape << " but was " << initShape;
    }
  }

  // TODO: verify visibility for declarations.
  return success();
}
1739
1740ElementsAttr GlobalOp::getConstantInitValue() {
1741 auto initVal = getInitialValue();
1742 if (getConstant() && initVal.has_value())
1743 return llvm::cast<ElementsAttr>(initVal.value());
1744 return {};
1745}
1746
1747//===----------------------------------------------------------------------===//
1748// GetGlobalOp
1749//===----------------------------------------------------------------------===//
1750
1751LogicalResult
1752GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1753 // Verify that the result type is same as the type of the referenced
1754 // memref.global op.
1755 auto global =
1756 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1757 if (!global)
1758 return emitOpError("'")
1759 << getName() << "' does not reference a valid global memref";
1760
1761 Type resultType = getResult().getType();
1762 if (global.getType() != resultType)
1763 return emitOpError("result type ")
1764 << resultType << " does not match type " << global.getType()
1765 << " of the global memref @" << getName();
1766 return success();
1767}
1768
1769//===----------------------------------------------------------------------===//
1770// LoadOp
1771//===----------------------------------------------------------------------===//
1772
1773LogicalResult LoadOp::verify() {
1774 if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1775 return emitOpError("incorrect number of indices for load, expected ")
1776 << getMemRefType().getRank() << " but got " << getIndices().size();
1777 }
1778 return success();
1779}
1780
OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  // Fold load from a global constant memref.
  auto getGlobalOp = getMemref().getDefiningOp<memref::GetGlobalOp>();
  if (!getGlobalOp)
    return {};

  // Get to the memref.global defining the symbol.
  // NOTE(review): the call resolving `global` spans a line missing from
  // this excerpt; verify against upstream before relying on it.
      getGlobalOp, getGlobalOp.getNameAttr());
  if (!global)
    return {};
  // If it's a splat constant, we can fold irrespective of indices.
  auto splatAttr =
      dyn_cast_or_null<SplatElementsAttr>(global.getConstantInitValue());
  if (!splatAttr)
    return {};

  // Every element of a splat is the same, so the load folds to that value.
  return splatAttr.getSplatValue<Attribute>();
}
1804
FailureOr<std::optional<SmallVector<Value>>>
LoadOp::bubbleDownCasts(OpBuilder &builder) {
  // Delegates to a shared bubbleDownCasts helper for this op's result.
  // NOTE(review): the first line of the delegated call is missing from
  // this excerpt.
      getResult());
}
1810
1811//===----------------------------------------------------------------------===//
1812// MemorySpaceCastOp
1813//===----------------------------------------------------------------------===//
1814
1815void MemorySpaceCastOp::getAsmResultNames(
1816 function_ref<void(Value, StringRef)> setNameFn) {
1817 setNameFn(getResult(), "memspacecast");
1818}
1819
1820bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1821 if (inputs.size() != 1 || outputs.size() != 1)
1822 return false;
1823 Type a = inputs.front(), b = outputs.front();
1824 auto aT = llvm::dyn_cast<MemRefType>(a);
1825 auto bT = llvm::dyn_cast<MemRefType>(b);
1826
1827 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1828 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1829
1830 if (aT && bT) {
1831 if (aT.getElementType() != bT.getElementType())
1832 return false;
1833 if (aT.getLayout() != bT.getLayout())
1834 return false;
1835 if (aT.getShape() != bT.getShape())
1836 return false;
1837 return true;
1838 }
1839 if (uaT && ubT) {
1840 return uaT.getElementType() == ubT.getElementType();
1841 }
1842 return false;
1843}
1844
1845OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1846 // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1847 // t2)
1848 if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1849 getSourceMutable().assign(parentCast.getSource());
1850 return getResult();
1851 }
1852 return Value{};
1853}
1854
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
  // The cast's operand, viewed through the pointer-like interface.
  return getSource();
}
1858
TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
  // The cast's result, viewed through the pointer-like interface.
  return getDest();
}
1862
1863bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1864 PtrLikeTypeInterface src) {
1865 return isa<BaseMemRefType>(tgt) &&
1866 tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
1867}
1868
MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
    OpBuilder &b, PtrLikeTypeInterface tgt,
  // NOTE(review): the parameter line declaring `src` is missing from this
  // excerpt. Creates a fresh cast of `src` to the target type `tgt`.
  assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
  return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
}
1875
1876/// The only cast we recognize as promotable is to the generic space.
1877bool MemorySpaceCastOp::isSourcePromotable() {
1878 return getDest().getType().getMemorySpace() == nullptr;
1879}
1880
1881//===----------------------------------------------------------------------===//
1882// PrefetchOp
1883//===----------------------------------------------------------------------===//
1884
void PrefetchOp::print(OpAsmPrinter &p) {
  // Custom syntax:
  //   memref.prefetch %m[%indices], read|write, locality<N>, data|instr
  //     : memref-type
  p << " " << getMemref() << '[';
  p << ']' << ", " << (getIsWrite() ? "write" : "read");
  p << ", locality<" << getLocalityHint();
  p << ">, " << (getIsDataCache() ? "data" : "instr");
  // Elide attributes already rendered by the keywords above.
  // NOTE(review): the lines printing the indices and starting the attr-dict
  // call are missing from this excerpt.
      (*this)->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << getMemRefType();
}
1896
ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parses the custom form printed by PrefetchOp::print:
  //   memref.prefetch %m[%indices], read|write, locality<N>, data|instr
  //     : memref-type
  OpAsmParser::UnresolvedOperand memrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  // NOTE(review): the line parsing the bracketed index operand list is
  // missing from this excerpt.
  if (parser.parseOperand(memrefInfo) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  // The rw specifier is stored as the boolean `isWrite` attribute.
  if (readOrWrite != "read" && readOrWrite != "write")
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
                      parser.getBuilder().getBoolAttr(readOrWrite == "write"));

  // The cache-type keyword is stored as the boolean `isDataCache` attribute.
  if (cacheType != "data" && cacheType != "instr")
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
                      parser.getBuilder().getBoolAttr(cacheType == "data"));

  return success();
}
1934
1935LogicalResult PrefetchOp::verify() {
1936 if (getNumOperands() != 1 + getMemRefType().getRank())
1937 return emitOpError("too few indices");
1938
1939 return success();
1940}
1941
1942LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1943 SmallVectorImpl<OpFoldResult> &results) {
1944 // prefetch(memrefcast) -> prefetch
1945 return foldMemRefCast(*this);
1946}
1947
1948//===----------------------------------------------------------------------===//
1949// RankOp
1950//===----------------------------------------------------------------------===//
1951
1952OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1953 // Constant fold rank when the rank of the operand is known.
1954 auto type = getOperand().getType();
1955 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1956 if (shapedType && shapedType.hasRank())
1957 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1958 return IntegerAttr();
1959}
1960
1961//===----------------------------------------------------------------------===//
1962// ReinterpretCastOp
1963//===----------------------------------------------------------------------===//
1964
1965void ReinterpretCastOp::getAsmResultNames(
1966 function_ref<void(Value, StringRef)> setNameFn) {
1967 setNameFn(getResult(), "reinterpret_cast");
1968}
1969
1970/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1971/// `staticSizes` and `staticStrides` are automatically filled with
1972/// source-memref-rank sentinel values that encode dynamic entries.
1973void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1974 MemRefType resultType, Value source,
1975 OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1976 ArrayRef<OpFoldResult> strides,
1977 ArrayRef<NamedAttribute> attrs) {
1978 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1979 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1980 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1981 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1982 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1983 result.addAttributes(attrs);
1984 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1985 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1986 b.getDenseI64ArrayAttr(staticSizes),
1987 b.getDenseI64ArrayAttr(staticStrides));
1988}
1989
1990void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1991 Value source, OpFoldResult offset,
1992 ArrayRef<OpFoldResult> sizes,
1993 ArrayRef<OpFoldResult> strides,
1994 ArrayRef<NamedAttribute> attrs) {
1995 auto sourceType = cast<BaseMemRefType>(source.getType());
1996 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1997 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1998 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1999 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2000 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2001 auto stridedLayout = StridedLayoutAttr::get(
2002 b.getContext(), staticOffsets.front(), staticStrides);
2003 auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
2004 stridedLayout, sourceType.getMemorySpace());
2005 build(b, result, resultType, source, offset, sizes, strides, attrs);
2006}
2007
2008void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
2009 MemRefType resultType, Value source,
2010 int64_t offset, ArrayRef<int64_t> sizes,
2011 ArrayRef<int64_t> strides,
2012 ArrayRef<NamedAttribute> attrs) {
2013 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
2014 sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
2015 SmallVector<OpFoldResult> strideValues =
2016 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
2017 return b.getI64IntegerAttr(v);
2018 });
2019 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
2020 strideValues, attrs);
2021}
2022
2023void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
2024 MemRefType resultType, Value source, Value offset,
2025 ValueRange sizes, ValueRange strides,
2026 ArrayRef<NamedAttribute> attrs) {
2027 SmallVector<OpFoldResult> sizeValues =
2028 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
2029 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
2030 strides, [](Value v) -> OpFoldResult { return v; });
2031 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
2032}
2033
2034// TODO: ponder whether we want to allow missing trailing sizes/strides that are
2035// completed automatically, like we have for subview and extract_slice.
2036LogicalResult ReinterpretCastOp::verify() {
2037 // The source and result memrefs should be in the same memory space.
2038 auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
2039 auto resultType = llvm::cast<MemRefType>(getType());
2040 if (srcType.getMemorySpace() != resultType.getMemorySpace())
2041 return emitError("different memory spaces specified for source type ")
2042 << srcType << " and result memref type " << resultType;
2043 if (failed(verifyElementTypesMatch(*this, srcType, resultType, "source",
2044 "result")))
2045 return failure();
2046
2047 // Match sizes in result memref type and in static_sizes attribute.
2048 for (auto [idx, resultSize, expectedSize] :
2049 llvm::enumerate(resultType.getShape(), getStaticSizes())) {
2050 if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
2051 return emitError("expected result type with size = ")
2052 << (ShapedType::isDynamic(expectedSize)
2053 ? std::string("dynamic")
2054 : std::to_string(expectedSize))
2055 << " instead of " << resultSize << " in dim = " << idx;
2056 }
2057
2058 // Match offset and strides in static_offset and static_strides attributes. If
2059 // result memref type has no affine map specified, this will assume an
2060 // identity layout.
2061 int64_t resultOffset;
2062 SmallVector<int64_t, 4> resultStrides;
2063 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
2064 return emitError("expected result type to have strided layout but found ")
2065 << resultType;
2066
2067 // Match offset in result memref type and in static_offsets attribute.
2068 int64_t expectedOffset = getStaticOffsets().front();
2069 if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
2070 return emitError("expected result type with offset = ")
2071 << (ShapedType::isDynamic(expectedOffset)
2072 ? std::string("dynamic")
2073 : std::to_string(expectedOffset))
2074 << " instead of " << resultOffset;
2075
2076 // Match strides in result memref type and in static_strides attribute.
2077 for (auto [idx, resultStride, expectedStride] :
2078 llvm::enumerate(resultStrides, getStaticStrides())) {
2079 if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
2080 return emitError("expected result type with stride = ")
2081 << (ShapedType::isDynamic(expectedStride)
2082 ? std::string("dynamic")
2083 : std::to_string(expectedStride))
2084 << " instead of " << resultStride << " in dim = " << idx;
2085 }
2086
2087 return success();
2088}
2089
2090OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
2091 Value src = getSource();
2092 auto getPrevSrc = [&]() -> Value {
2093 // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
2094 if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
2095 return prev.getSource();
2096
2097 // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
2098 if (auto prev = src.getDefiningOp<CastOp>())
2099 return prev.getSource();
2100
2101 // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
2102 // are 0.
2103 if (auto prev = src.getDefiningOp<SubViewOp>())
2104 if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
2105 return prev.getSource();
2106
2107 return nullptr;
2108 };
2109
2110 if (auto prevSrc = getPrevSrc()) {
2111 getSourceMutable().assign(prevSrc);
2112 return getResult();
2113 }
2114
2115 // reinterpret_cast(x) w/o offset/shape/stride changes -> x
2116 if (ShapedType::isStaticShape(getType().getShape()) &&
2117 src.getType() == getType() && getStaticOffsets().front() == 0) {
2118 return src;
2119 }
2120
2121 return nullptr;
2122}
2123
SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
  // Mixed sizes, with dynamic entries replaced by constants wherever the
  // result type pins them down.
  // NOTE(review): the constify call is on a line missing from this excerpt.
  SmallVector<OpFoldResult> values = getMixedSizes();
  return values;
}
2129
2130SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
2131 SmallVector<OpFoldResult> values = getMixedStrides();
2132 SmallVector<int64_t> staticValues;
2133 int64_t unused;
2134 LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
2135 (void)status;
2136 assert(succeeded(status) && "could not get strides from type");
2137 constifyIndexValues(values, staticValues);
2138 return values;
2139}
2140
2141OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
2142 SmallVector<OpFoldResult> values = getMixedOffsets();
2143 assert(values.size() == 1 &&
2144 "reinterpret_cast must have one and only one offset");
2145 SmallVector<int64_t> staticValues, unused;
2146 int64_t offset;
2147 LogicalResult status = getType().getStridesAndOffset(unused, offset);
2148 (void)status;
2149 assert(succeeded(status) && "could not get offset from type");
2150 staticValues.push_back(offset);
2151 constifyIndexValues(values, staticValues);
2152 return values[0];
2153}
2154
2155namespace {
2156/// Replace the sequence:
2157/// ```
2158/// base, offset, sizes, strides = extract_strided_metadata src
2159/// dst = reinterpret_cast base to offset, sizes, strides
2160/// ```
2161/// With
2162///
2163/// ```
2164/// dst = memref.cast src
2165/// ```
2166///
2167/// Note: The cast operation is only inserted when the type of dst and src
2168/// are not the same. E.g., when going from <4xf32> to <?xf32>.
2169///
2170/// This pattern also matches when the offset, sizes, and strides don't come
2171/// directly from the `extract_strided_metadata`'s results but it can be
2172/// statically proven that they would hold the same values.
2173///
2174/// For instance, the following sequence would be replaced:
2175/// ```
2176/// base, offset, sizes, strides =
2177/// extract_strided_metadata memref : memref<3x4xty>
2178/// dst = reinterpret_cast base to 0, [3, 4], strides
2179/// ```
2180/// Because we know (thanks to the type of the input memref) that variable
2181/// `offset` and `sizes` will respectively hold 0 and [3, 4].
2182///
2183/// Similarly, the following sequence would be replaced:
2184/// ```
2185/// c0 = arith.constant 0
2186/// c4 = arith.constant 4
2187/// base, offset, sizes, strides =
2188/// extract_strided_metadata memref : memref<3x4xty>
2189/// dst = reinterpret_cast base to c0, [3, c4], strides
2190/// ```
2191/// Because we know that `offset`and `c0` will hold 0
2192/// and `c4` will hold 4.
2193///
2194/// If the pattern above does not match, the input of the
2195/// extract_strided_metadata is always folded into the input of the
2196/// reinterpret_cast operator. This allows for dead code elimination to get rid
2197/// of the extract_strided_metadata in some cases.
struct ReinterpretCastOpExtractStridedMetadataFolder
    : public OpRewritePattern<ReinterpretCastOp> {
public:
  using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ReinterpretCastOp op,
                                PatternRewriter &rewriter) const override {
    // Only fires when the cast's source is the base pointer produced by an
    // extract_strided_metadata op.
    auto extractStridedMetadata =
        op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
    if (!extractStridedMetadata)
      return failure();

    // Check if the reinterpret cast reconstructs a memref with the exact same
    // properties as the extract strided metadata.
    auto isReinterpretCastNoop = [&]() -> bool {
      // First, check that the strides are the same.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
                       op.getConstifiedMixedStrides()))
        return false;

      // Second, check the sizes.
      if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
                       op.getConstifiedMixedSizes()))
        return false;

      // Finally, check the offset.
      assert(op.getMixedOffsets().size() == 1 &&
             "reinterpret_cast with more than one offset should have been "
             "rejected by the verifier");
      return extractStridedMetadata.getConstifiedMixedOffset() ==
             op.getConstifiedMixedOffset();
    };

    if (!isReinterpretCastNoop()) {
      // If the extract_strided_metadata / reinterpret_cast pair can't be
      // completely folded, then we could fold the input of the
      // extract_strided_metadata into the input of the reinterpret_cast
      // input. For some cases (e.g., static dimensions) the
      // the extract_strided_metadata is eliminated by dead code elimination.
      //
      // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
      //
      // We can always fold the input of a extract_strided_metadata operator
      // to the input of a reinterpret_cast operator, because they point to
      // the same memory. Note that the reinterpret_cast does not use the
      // layout of its input memref, only its base memory pointer which is
      // the same as the base pointer returned by the extract_strided_metadata
      // operator and the base pointer of the extract_strided_metadata memref
      // input.
      rewriter.modifyOpInPlace(op, [&]() {
        op.getSourceMutable().assign(extractStridedMetadata.getSource());
      });
      return success();
    }

    // At this point, we know that the back and forth between extract strided
    // metadata and reinterpret cast is a noop. However, the final type of the
    // reinterpret cast may not be exactly the same as the original memref.
    // E.g., it could be changing a dimension from static to dynamic. Check that
    // here and add a cast if necessary.
    Type srcTy = extractStridedMetadata.getSource().getType();
    if (srcTy == op.getResult().getType())
      rewriter.replaceOp(op, extractStridedMetadata.getSource());
    else
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
                                          extractStridedMetadata.getSource());

    return success();
  }
};
2268
2269struct ReinterpretCastOpConstantFolder
2270 : public OpRewritePattern<ReinterpretCastOp> {
2271public:
2272 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2273
2274 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2275 PatternRewriter &rewriter) const override {
2276 unsigned srcStaticCount = llvm::count_if(
2277 llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
2278 op.getMixedStrides()),
2279 [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
2280
2281 SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
2282 SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
2283 SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
2284
2285 // TODO: Using counting comparison instead of direct comparison because
2286 // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
2287 // IntegerAttrs, while constifyIndexValues (and therefore
2288 // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
2289 if (srcStaticCount ==
2290 llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
2291 [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
2292 return failure();
2293
2294 auto newReinterpretCast = ReinterpretCastOp::create(
2295 rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
2296
2297 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
2298 return success();
2299 }
2300};
2301} // namespace
2302
2303void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2304 MLIRContext *context) {
2305 results.add<ReinterpretCastOpExtractStridedMetadataFolder,
2306 ReinterpretCastOpConstantFolder>(context);
2307}
2308
FailureOr<std::optional<SmallVector<Value>>>
ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
  // Delegates to the shared pass-through bubbleDownCasts helper, acting on
  // the source operand.
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
2313
2314//===----------------------------------------------------------------------===//
2315// Reassociative reshape ops
2316//===----------------------------------------------------------------------===//
2317
2318void CollapseShapeOp::getAsmResultNames(
2319 function_ref<void(Value, StringRef)> setNameFn) {
2320 setNameFn(getResult(), "collapse_shape");
2321}
2322
2323void ExpandShapeOp::getAsmResultNames(
2324 function_ref<void(Value, StringRef)> setNameFn) {
2325 setNameFn(getResult(), "expand_shape");
2326}
2327
2328LogicalResult ExpandShapeOp::reifyResultShapes(
2329 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2330 reifiedResultShapes = {
2331 getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2332 return success();
2333}
2334
2335/// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
2336/// result and operand. Layout maps are verified separately.
2337///
2338/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2339/// allowed in a reassocation group.
// NOTE(review): the signature line naming this helper and declaring the op
// and `collapsedShape` parameters is missing from this excerpt.
static LogicalResult
                     ArrayRef<int64_t> expandedShape,
                     ArrayRef<ReassociationIndices> reassociation,
                     bool allowMultipleDynamicDimsPerGroup) {
  // There must be one reassociation group per collapsed dimension.
  if (collapsedShape.size() != reassociation.size())
    return op->emitOpError("invalid number of reassociation groups: found ")
           << reassociation.size() << ", expected " << collapsedShape.size();

  // The next expected expanded dimension index (while iterating over
  // reassociation indices).
  int64_t nextDim = 0;
  for (const auto &it : llvm::enumerate(reassociation)) {
    ReassociationIndices group = it.value();
    int64_t collapsedDim = it.index();

    // Track whether this group contains any dynamic expanded dimension.
    bool foundDynamic = false;
    for (int64_t expandedDim : group) {
      // Groups must cover the expanded dims in order, without gaps.
      if (expandedDim != nextDim++)
        return op->emitOpError("reassociation indices must be contiguous");

      if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
        return op->emitOpError("reassociation index ")
               << expandedDim << " is out of bounds";

      // Check if there are multiple dynamic dims in a reassociation group.
      if (ShapedType::isDynamic(expandedShape[expandedDim])) {
        if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
          return op->emitOpError(
              "at most one dimension in a reassociation group may be dynamic");
        foundDynamic = true;
      }
    }

    // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
    if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
      return op->emitOpError("collapsed dim (")
             << collapsedDim
             << ") must be dynamic if and only if reassociation group is "
                "dynamic";

    // If all dims in the reassociation group are static, the size of the
    // collapsed dim can be verified.
    if (!foundDynamic) {
      int64_t groupSize = 1;
      for (int64_t expandedDim : group)
        groupSize *= expandedShape[expandedDim];
      if (groupSize != collapsedShape[collapsedDim])
        return op->emitOpError("collapsed dim size (")
               << collapsedShape[collapsedDim]
               << ") must equal reassociation group size (" << groupSize << ")";
    }
  }

  if (collapsedShape.empty()) {
    // Rank 0: All expanded dimensions must be 1.
    for (int64_t d : expandedShape)
      if (d != 1)
        return op->emitOpError(
            "rank 0 memrefs can only be extended/collapsed with/from ones");
  } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
    // Rank >= 1: Number of dimensions among all reassociation groups must match
    // the result memref rank.
    return op->emitOpError("expanded rank (")
           << expandedShape.size()
           << ") inconsistent with number of reassociation indices (" << nextDim
           << ")";
  }

  return success();
}
2412
2413SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2414 return getSymbolLessAffineMaps(getReassociationExprs());
2415}
2416
SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
  // Convert the index-based reassociation attribute into affine-expr form.
  // NOTE(review): the first line of the conversion call is missing from
  // this excerpt.
      getReassociationIndices());
}
2421
2422SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2423 return getSymbolLessAffineMaps(getReassociationExprs());
2424}
2425
SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
  // Convert the index-based reassociation attribute into affine-expr form.
  // NOTE(review): the first line of the conversion call is missing from
  // this excerpt.
      getReassociationIndices());
}
2430
2431/// Compute the layout map after expanding a given source MemRef type with the
2432/// specified reassociation indices.
static FailureOr<StridedLayoutAttr>
computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
                         ArrayRef<ReassociationIndices> reassociation) {
  // Fails when the source does not have a strided layout.
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();
  assert(srcStrides.size() == reassociation.size() && "invalid reassociation");

  // 1-1 mapping between srcStrides and reassociation packs.
  // Each srcStride starts with the given value and gets expanded according to
  // the proper entries in resultShape.
  // Example:
  //   srcStrides     =                         [10000,  1 ,    100   ],
  //   reassociations =                         [  [0],  [1],  [2, 3, 4]],
  //   resultSizes    = [2, 5, 4, 3, 2] = [  [2],  [5],  [4, 3, 2]]
  //     -> For the purpose of stride calculation, the useful sizes are:
  //        [x, x, x, 3, 2] = [  [x],  [x],  [x, 3, 2]].
  //   resultStrides = [10000, 1, 600, 200, 100]
  // Note that a stride does not get expanded along the first entry of each
  // shape pack.
  SmallVector<int64_t> reverseResultStrides;
  reverseResultStrides.reserve(resultShape.size());
  // Walk groups and result dims right-to-left so each stride can be built
  // by multiplying up the inner sizes; `shapeIndex--` consumes result dims.
  unsigned shapeIndex = resultShape.size() - 1;
  for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
    ReassociationIndices reassoc = std::get<0>(it);
    int64_t currentStrideToExpand = std::get<1>(it);
    for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
      reverseResultStrides.push_back(currentStrideToExpand);
      // Saturated arithmetic propagates the "dynamic" sentinel through the
      // multiplication when a size is dynamic.
      currentStrideToExpand =
          (SaturatedInteger::wrap(currentStrideToExpand) *
           SaturatedInteger::wrap(resultShape[shapeIndex--]))
              .asInteger();
    }
  }
  auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
  // Pad with unit strides; relevant when the source is rank 0 and all result
  // dims are ones.
  resultStrides.resize(resultShape.size(), 1);
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}
2472
2473FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2474 MemRefType srcType, ArrayRef<int64_t> resultShape,
2475 ArrayRef<ReassociationIndices> reassociation) {
2476 if (srcType.getLayout().isIdentity()) {
2477 // If the source is contiguous (i.e., no layout map specified), so is the
2478 // result.
2479 MemRefLayoutAttrInterface layout;
2480 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2481 srcType.getMemorySpace());
2482 }
2483
2484 // Source may not be contiguous. Compute the layout map.
2485 FailureOr<StridedLayoutAttr> computedLayout =
2486 computeExpandedLayoutMap(srcType, resultShape, reassociation);
2487 if (failed(computedLayout))
2488 return failure();
2489 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2490 srcType.getMemorySpace());
2491}
2492
2493FailureOr<SmallVector<OpFoldResult>>
2494ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2495 MemRefType expandedType,
2496 ArrayRef<ReassociationIndices> reassociation,
2497 ArrayRef<OpFoldResult> inputShape) {
2498 std::optional<SmallVector<OpFoldResult>> outputShape =
2499 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2500 inputShape);
2501 if (!outputShape)
2502 return failure();
2503 return *outputShape;
2504}
2505
2506void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2507 Type resultType, Value src,
2508 ArrayRef<ReassociationIndices> reassociation,
2509 ArrayRef<OpFoldResult> outputShape) {
2510 auto [staticOutputShape, dynamicOutputShape] =
2511 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2512 build(builder, result, llvm::cast<MemRefType>(resultType), src,
2513 getReassociationIndicesAttribute(builder, reassociation),
2514 dynamicOutputShape, staticOutputShape);
2515}
2516
2517void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2518 Type resultType, Value src,
2519 ArrayRef<ReassociationIndices> reassociation) {
2520 SmallVector<OpFoldResult> inputShape =
2521 getMixedSizes(builder, result.location, src);
2522 MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2523 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2524 builder, result.location, memrefResultTy, reassociation, inputShape);
2525 // Failure of this assertion usually indicates presence of multiple
2526 // dynamic dimensions in the same reassociation group.
2527 assert(succeeded(outputShape) && "unable to infer output shape");
2528 build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2529}
2530
2531void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2532 ArrayRef<int64_t> resultShape, Value src,
2533 ArrayRef<ReassociationIndices> reassociation) {
2534 // Only ranked memref source values are supported.
2535 auto srcType = llvm::cast<MemRefType>(src.getType());
2536 FailureOr<MemRefType> resultType =
2537 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2538 // Failure of this assertion usually indicates a problem with the source
2539 // type, e.g., could not get strides/offset.
2540 assert(succeeded(resultType) && "could not compute layout");
2541 build(builder, result, *resultType, src, reassociation);
2542}
2543
2544void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2545 ArrayRef<int64_t> resultShape, Value src,
2546 ArrayRef<ReassociationIndices> reassociation,
2547 ArrayRef<OpFoldResult> outputShape) {
2548 // Only ranked memref source values are supported.
2549 auto srcType = llvm::cast<MemRefType>(src.getType());
2550 FailureOr<MemRefType> resultType =
2551 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2552 // Failure of this assertion usually indicates a problem with the source
2553 // type, e.g., could not get strides/offset.
2554 assert(succeeded(resultType) && "could not compute layout");
2555 build(builder, result, *resultType, src, reassociation, outputShape);
2556}
2557
/// Verify an expand_shape op: rank relation, result shape vs. reassociation,
/// expected result type (including layout), and consistency of the
/// output_shape / static_output_shape operands with the result type.
LogicalResult ExpandShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  // An expansion must not decrease the rank.
  if (srcType.getRank() > resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not an expansion ("
           << r0 << " > " << r1 << ").";
  }

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
                                  resultType.getShape(),
                                  getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute expected result type (including layout map).
  FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
      srcType, resultType.getShape(), getReassociationIndices());
  if (failed(expectedResultType))
    return emitOpError("invalid source layout map");

  // Check actual result type.
  if (*expectedResultType != resultType)
    return emitOpError("expected expanded type to be ")
           << *expectedResultType << " but found " << resultType;

  // static_output_shape must provide one entry per result dimension.
  if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
    return emitOpError("expected number of static shape bounds to be equal to "
                       "the output rank (")
           << resultType.getRank() << ") but found "
           << getStaticOutputShape().size() << " inputs instead";

  // Each kDynamic entry in static_output_shape must be backed by exactly one
  // SSA value in output_shape.
  if ((int64_t)getOutputShape().size() !=
      llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
    return emitOpError("mismatch in dynamic dims in output_shape and "
                       "static_output_shape: static_output_shape has ")
           << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
           << " dynamic dims while output_shape has " << getOutputShape().size()
           << " values";

  // Verify that the number of dynamic dims in output_shape matches the number
  // of dynamic dims in the result type.
  if (failed(verifyDynamicDimensionCount(getOperation(), resultType,
                                         getOutputShape())))
    return failure();

  // Verify if provided output shapes are in agreement with output type.
  DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
  ArrayRef<int64_t> resShape = getResult().getType().getShape();
  for (auto [pos, shape] : llvm::enumerate(resShape)) {
    if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
      return emitOpError("invalid output shape provided at pos ") << pos;
    }
  }

  return success();
}
2619
/// Fold a producing memref.cast into memref.expand_shape: promote output
/// shape entries to static values where the cast makes them constant, and
/// re-cast the result if the expanded type changes.
struct ExpandShapeOpMemRefCastFolder : public OpRewritePattern<ExpandShapeOp> {
public:
  using OpRewritePattern<ExpandShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ExpandShapeOp op,
                                PatternRewriter &rewriter) const override {
    // Only fire when the source is produced by a foldable memref.cast.
    auto cast = op.getSrc().getDefiningOp<CastOp>();
    if (!cast)
      return failure();

    if (!CastOp::canFoldIntoConsumerOp(cast))
      return failure();

    SmallVector<OpFoldResult> originalOutputShape = op.getMixedOutputShape();
    SmallVector<OpFoldResult> newOutputShape = originalOutputShape;
    SmallVector<int64_t> newOutputShapeSizes;

    // Convert output shape dims from dynamic to static where possible.
    for (auto [dimIdx, dimSize] : enumerate(originalOutputShape)) {
      std::optional<int64_t> sizeOpt = getConstantIntValue(dimSize);
      if (!sizeOpt.has_value()) {
        newOutputShapeSizes.push_back(ShapedType::kDynamic);
        continue;
      }

      newOutputShapeSizes.push_back(sizeOpt.value());
      newOutputShape[dimIdx] = rewriter.getIndexAttr(sizeOpt.value());
    }

    Value castSource = cast.getSource();
    auto castSourceType = llvm::cast<MemRefType>(castSource.getType());
    SmallVector<ReassociationIndices> reassociationIndices =
        op.getReassociationIndices();
    // Bail out if folding the cast would flip any reassociation group between
    // dynamic and static — that would change the op's type contract.
    for (auto [idx, group] : llvm::enumerate(reassociationIndices)) {
      auto newOutputShapeSizesSlice =
          ArrayRef(newOutputShapeSizes).slice(group.front(), group.size());
      bool newOutputDynamic =
          llvm::is_contained(newOutputShapeSizesSlice, ShapedType::kDynamic);
      if (castSourceType.isDynamicDim(idx) != newOutputDynamic)
        return rewriter.notifyMatchFailure(
            op, "folding cast will result in changing dynamicity in "
                "reassociation group");
    }

    FailureOr<MemRefType> newResultTypeOrFailure =
        ExpandShapeOp::computeExpandedType(castSourceType, newOutputShapeSizes,
                                           reassociationIndices);

    if (failed(newResultTypeOrFailure))
      return rewriter.notifyMatchFailure(
          op, "could not compute new expanded type after folding cast");

    if (*newResultTypeOrFailure == op.getResultType()) {
      // Same result type: just rewire the source in place.
      rewriter.modifyOpInPlace(
          op, [&]() { op.getSrcMutable().assign(castSource); });
    } else {
      // Result type changed: build a new expand_shape and cast back to the
      // original type so existing users are unaffected.
      Value newOp = ExpandShapeOp::create(rewriter, op->getLoc(),
                                          *newResultTypeOrFailure, castSource,
                                          reassociationIndices, newOutputShape);
      rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
    }
    return success();
  }
};
2684
2685void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2686 MLIRContext *context) {
2687 results.add<
2688 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2689 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp, CastOp>,
2690 ExpandShapeOpMemRefCastFolder>(context);
2691}
2692
/// Bubble producing casts below this op via the shared passthrough
/// implementation on the `src` operand.
FailureOr<std::optional<SmallVector<Value>>>
ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
}
2697
/// Compute the layout map after collapsing a given source MemRef type with the
/// specified reassociation indices.
///
/// Note: All collapsed dims in a reassociation group must be contiguous. It is
/// not possible to check this by inspecting a MemRefType in the general case.
/// If non-contiguity cannot be checked statically, the collapse is assumed to
/// be valid (and thus accepted by this function) unless `strict = true`.
static FailureOr<StridedLayoutAttr>
computeCollapsedLayoutMap(MemRefType srcType,
                          ArrayRef<ReassociationIndices> reassociation,
                          bool strict = false) {
  int64_t srcOffset;
  SmallVector<int64_t> srcStrides;
  auto srcShape = srcType.getShape();
  // Fail if the source has no statically analyzable strided layout.
  if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
    return failure();

  // The result stride of a reassociation group is the stride of the last entry
  // of the reassociation. (TODO: Should be the minimum stride in the
  // reassociation because strides are not necessarily sorted. E.g., when using
  // memref.transpose.) Dimensions of size 1 should be skipped, because their
  // strides are meaningless and could have any arbitrary value.
  SmallVector<int64_t> resultStrides;
  resultStrides.reserve(reassociation.size());
  for (const ReassociationIndices &reassoc : reassociation) {
    ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
    // Drop trailing size-1 dims; their strides carry no information.
    while (srcShape[ref.back()] == 1 && ref.size() > 1)
      ref = ref.drop_back();
    if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
      resultStrides.push_back(srcStrides[ref.back()]);
    } else {
      // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
      // the corresponding stride may have to be skipped. (See above comment.)
      // Therefore, the result stride cannot be statically determined and must
      // be dynamic.
      resultStrides.push_back(ShapedType::kDynamic);
    }
  }

  // Validate that each reassociation group is contiguous.
  unsigned resultStrideIndex = resultStrides.size() - 1;
  for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
    auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
    auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
    for (int64_t idx : llvm::reverse(trailingReassocs)) {
      // Accumulate the stride that dim `idx - 1` must have for the group to
      // be contiguous (saturating arithmetic handles dynamic values).
      stride = stride * SaturatedInteger::wrap(srcShape[idx]);

      // Dimensions of size 1 should be skipped, because their strides are
      // meaningless and could have any arbitrary value.
      if (srcShape[idx - 1] == 1)
        continue;

      // Both source and result stride must have the same static value. In that
      // case, we can be sure, that the dimensions are collapsible (because they
      // are contiguous).
      // If `strict = false` (default during op verification), we accept cases
      // where one or both strides are dynamic. This is best effort: We reject
      // ops where obviously non-contiguous dims are collapsed, but accept ops
      // where we cannot be sure statically. Such ops may fail at runtime. See
      // the op documentation for details.
      auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
      if (strict && (stride.saturated || srcStride.saturated))
        return failure();

      if (!stride.saturated && !srcStride.saturated && stride != srcStride)
        return failure();
    }
  }
  return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
}
2768
2769bool CollapseShapeOp::isGuaranteedCollapsible(
2770 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2771 // MemRefs with identity layout are always collapsible.
2772 if (srcType.getLayout().isIdentity())
2773 return true;
2774
2775 return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2776 /*strict=*/true));
2777}
2778
2779MemRefType CollapseShapeOp::computeCollapsedType(
2780 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2781 SmallVector<int64_t> resultShape;
2782 resultShape.reserve(reassociation.size());
2783 for (const ReassociationIndices &group : reassociation) {
2784 auto groupSize = SaturatedInteger::wrap(1);
2785 for (int64_t srcDim : group)
2786 groupSize =
2787 groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2788 resultShape.push_back(groupSize.asInteger());
2789 }
2790
2791 if (srcType.getLayout().isIdentity()) {
2792 // If the source is contiguous (i.e., no layout map specified), so is the
2793 // result.
2794 MemRefLayoutAttrInterface layout;
2795 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2796 srcType.getMemorySpace());
2797 }
2798
2799 // Source may not be fully contiguous. Compute the layout map.
2800 // Note: Dimensions that are collapsed into a single dim are assumed to be
2801 // contiguous.
2802 FailureOr<StridedLayoutAttr> computedLayout =
2803 computeCollapsedLayoutMap(srcType, reassociation);
2804 assert(succeeded(computedLayout) &&
2805 "invalid source layout map or collapsing non-contiguous dims");
2806 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2807 srcType.getMemorySpace());
2808}
2809
2810void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2811 ArrayRef<ReassociationIndices> reassociation,
2812 ArrayRef<NamedAttribute> attrs) {
2813 auto srcType = llvm::cast<MemRefType>(src.getType());
2814 MemRefType resultType =
2815 CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2817 getReassociationIndicesAttribute(b, reassociation));
2818 build(b, result, resultType, src, attrs);
2819}
2820
/// Verify a collapse_shape op: rank relation, result shape vs. reassociation,
/// and that the result type matches the type computed from the source layout.
LogicalResult CollapseShapeOp::verify() {
  MemRefType srcType = getSrcType();
  MemRefType resultType = getResultType();

  // A collapse must not increase the rank.
  if (srcType.getRank() < resultType.getRank()) {
    auto r0 = srcType.getRank();
    auto r1 = resultType.getRank();
    return emitOpError("has source rank ")
           << r0 << " and result rank " << r1 << ". This is not a collapse ("
           << r0 << " < " << r1 << ").";
  }

  // Verify result shape.
  if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
                                  srcType.getShape(), getReassociationIndices(),
                                  /*allowMultipleDynamicDimsPerGroup=*/true)))
    return failure();

  // Compute expected result type (including layout map).
  MemRefType expectedResultType;
  if (srcType.getLayout().isIdentity()) {
    // If the source is contiguous (i.e., no layout map specified), so is the
    // result.
    MemRefLayoutAttrInterface layout;
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
                        srcType.getMemorySpace());
  } else {
    // Source may not be fully contiguous. Compute the layout map.
    // Note: Dimensions that are collapsed into a single dim are assumed to be
    // contiguous.
    FailureOr<StridedLayoutAttr> computedLayout =
        computeCollapsedLayoutMap(srcType, getReassociationIndices());
    if (failed(computedLayout))
      return emitOpError(
          "invalid source layout map or collapsing non-contiguous dims");
    expectedResultType =
        MemRefType::get(resultType.getShape(), srcType.getElementType(),
                        *computedLayout, srcType.getMemorySpace());
  }

  if (expectedResultType != resultType)
    return emitOpError("expected collapsed type to be ")
           << expectedResultType << " but found " << resultType;

  return success();
}
2868
2870 : public OpRewritePattern<CollapseShapeOp> {
2871public:
2872 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2873
2874 LogicalResult matchAndRewrite(CollapseShapeOp op,
2875 PatternRewriter &rewriter) const override {
2876 auto cast = op.getOperand().getDefiningOp<CastOp>();
2877 if (!cast)
2878 return failure();
2879
2880 if (!CastOp::canFoldIntoConsumerOp(cast))
2881 return failure();
2882
2883 Type newResultType = CollapseShapeOp::computeCollapsedType(
2884 llvm::cast<MemRefType>(cast.getOperand().getType()),
2885 op.getReassociationIndices());
2886
2887 if (newResultType == op.getResultType()) {
2888 rewriter.modifyOpInPlace(
2889 op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2890 } else {
2891 Value newOp =
2892 CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
2893 op.getReassociationIndices());
2894 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2895 }
2896 return success();
2897 }
2898};
2899
2900void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2901 MLIRContext *context) {
2902 results.add<
2903 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2904 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2905 memref::DimOp, MemRefType>,
2906 CollapseShapeOpMemRefCastFolder>(context);
2907}
2908
2909OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2911 adaptor.getOperands());
2912}
2913
2914OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2916 adaptor.getOperands());
2917}
2918
/// Bubble producing casts below this op via the shared passthrough
/// implementation on the `src` operand.
FailureOr<std::optional<SmallVector<Value>>>
CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
}
2923
2924//===----------------------------------------------------------------------===//
2925// ReshapeOp
2926//===----------------------------------------------------------------------===//
2927
/// Suggest "reshape" as the SSA name prefix for this op's result.
void ReshapeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "reshape");
}
2932
2933LogicalResult ReshapeOp::verify() {
2934 Type operandType = getSource().getType();
2935 Type resultType = getResult().getType();
2936
2937 Type operandElementType =
2938 llvm::cast<ShapedType>(operandType).getElementType();
2939 Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2940 if (operandElementType != resultElementType)
2941 return emitOpError("element types of source and destination memref "
2942 "types should be the same");
2943
2944 if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2945 if (!operandMemRefType.getLayout().isIdentity())
2946 return emitOpError("source memref type should have identity affine map");
2947
2948 int64_t shapeSize =
2949 llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2950 auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2951 if (resultMemRefType) {
2952 if (!resultMemRefType.getLayout().isIdentity())
2953 return emitOpError("result memref type should have identity affine map");
2954 if (shapeSize == ShapedType::kDynamic)
2955 return emitOpError("cannot use shape operand with dynamic length to "
2956 "reshape to statically-ranked memref type");
2957 if (shapeSize != resultMemRefType.getRank())
2958 return emitOpError(
2959 "length of shape operand differs from the result's memref rank");
2960 }
2961 return success();
2962}
2963
/// Bubble producing casts below this op via the shared passthrough
/// implementation on the `source` operand.
FailureOr<std::optional<SmallVector<Value>>>
ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
2968
2969//===----------------------------------------------------------------------===//
2970// StoreOp
2971//===----------------------------------------------------------------------===//
2972
/// Verify a store op: operands are the stored value, the memref, then one
/// index per memref dimension.
LogicalResult StoreOp::verify() {
  if (getNumOperands() != 2 + getMemRefType().getRank())
    return emitOpError("store index operand count not equal to memref rank");

  return success();
}
2979
/// Fold producing memref.cast ops into the store's non-value operands.
LogicalResult StoreOp::fold(FoldAdaptor adaptor,
                            SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  // `getValueToStore()` is excluded from folding (it is not a memref operand).
  return foldMemRefCast(*this, getValueToStore());
}
2985
FailureOr<std::optional<SmallVector<Value>>>
StoreOp::bubbleDownCasts(OpBuilder &builder) {
  // NOTE(review): the callee of this return statement appears to have been
  // lost in extraction; only the trailing `ValueRange())` argument survives.
  // Restore the original call from upstream MemRefOps.cpp — TODO confirm.
                                             ValueRange());
}
2991
2992//===----------------------------------------------------------------------===//
2993// SubViewOp
2994//===----------------------------------------------------------------------===//
2995
/// Suggest "subview" as the SSA name prefix for this op's result.
void SubViewOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "subview");
}
3000
3001/// A subview result type can be fully inferred from the source type and the
3002/// static representation of offsets, sizes and strides. Special sentinels
3003/// encode the dynamic case.
3004MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
3005 ArrayRef<int64_t> staticOffsets,
3006 ArrayRef<int64_t> staticSizes,
3007 ArrayRef<int64_t> staticStrides) {
3008 unsigned rank = sourceMemRefType.getRank();
3009 (void)rank;
3010 assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
3011 assert(staticSizes.size() == rank && "staticSizes length mismatch");
3012 assert(staticStrides.size() == rank && "staticStrides length mismatch");
3013
3014 // Extract source offset and strides.
3015 auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
3016
3017 // Compute target offset whose value is:
3018 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
3019 int64_t targetOffset = sourceOffset;
3020 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
3021 auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
3022 targetOffset = (SaturatedInteger::wrap(targetOffset) +
3023 SaturatedInteger::wrap(staticOffset) *
3024 SaturatedInteger::wrap(sourceStride))
3025 .asInteger();
3026 }
3027
3028 // Compute target stride whose value is:
3029 // `sourceStrides_i * staticStrides_i`.
3030 SmallVector<int64_t, 4> targetStrides;
3031 targetStrides.reserve(staticOffsets.size());
3032 for (auto it : llvm::zip(sourceStrides, staticStrides)) {
3033 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
3034 targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
3035 SaturatedInteger::wrap(staticStride))
3036 .asInteger());
3037 }
3038
3039 // The type is now known.
3040 return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
3041 StridedLayoutAttr::get(sourceMemRefType.getContext(),
3042 targetOffset, targetStrides),
3043 sourceMemRefType.getMemorySpace());
3044}
3045
3046MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
3047 ArrayRef<OpFoldResult> offsets,
3048 ArrayRef<OpFoldResult> sizes,
3049 ArrayRef<OpFoldResult> strides) {
3050 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3051 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3052 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3053 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3054 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3055 if (!hasValidSizesOffsets(staticOffsets))
3056 return {};
3057 if (!hasValidSizesOffsets(staticSizes))
3058 return {};
3059 if (!hasValidStrides(staticStrides))
3060 return {};
3061 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
3062 staticSizes, staticStrides);
3063}
3064
3065MemRefType SubViewOp::inferRankReducedResultType(
3066 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
3067 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3068 ArrayRef<int64_t> strides) {
3069 MemRefType inferredType =
3070 inferResultType(sourceRankedTensorType, offsets, sizes, strides);
3071 assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
3072 "expected ");
3073 if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
3074 return inferredType;
3075
3076 // Compute which dimensions are dropped.
3077 std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
3078 computeRankReductionMask(inferredType.getShape(), resultShape);
3079 assert(dimsToProject.has_value() && "invalid rank reduction");
3080
3081 // Compute the layout and result type.
3082 auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
3083 SmallVector<int64_t> rankReducedStrides;
3084 rankReducedStrides.reserve(resultShape.size());
3085 for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
3086 if (!dimsToProject->contains(idx))
3087 rankReducedStrides.push_back(value);
3088 }
3089 return MemRefType::get(resultShape, inferredType.getElementType(),
3090 StridedLayoutAttr::get(inferredLayout.getContext(),
3091 inferredLayout.getOffset(),
3092 rankReducedStrides),
3093 inferredType.getMemorySpace());
3094}
3095
3096MemRefType SubViewOp::inferRankReducedResultType(
3097 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
3098 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
3099 ArrayRef<OpFoldResult> strides) {
3100 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3101 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3102 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3103 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3104 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3105 return SubViewOp::inferRankReducedResultType(
3106 resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
3107 staticStrides);
3108}
3109
3110// Build a SubViewOp with mixed static and dynamic entries and custom result
3111// type. If the type passed is nullptr, it is inferred.
3112void SubViewOp::build(OpBuilder &b, OperationState &result,
3113 MemRefType resultType, Value source,
3114 ArrayRef<OpFoldResult> offsets,
3115 ArrayRef<OpFoldResult> sizes,
3116 ArrayRef<OpFoldResult> strides,
3117 ArrayRef<NamedAttribute> attrs) {
3118 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
3119 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
3120 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
3121 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
3122 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
3123 auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
3124 // Structuring implementation this way avoids duplication between builders.
3125 if (!resultType) {
3126 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
3127 staticSizes, staticStrides);
3128 }
3129 result.addAttributes(attrs);
3130 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
3131 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
3132 b.getDenseI64ArrayAttr(staticSizes),
3133 b.getDenseI64ArrayAttr(staticStrides));
3134}
3135
3136// Build a SubViewOp with mixed static and dynamic entries and inferred result
3137// type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides,
                      ArrayRef<NamedAttribute> attrs) {
  // Passing a null MemRefType requests result-type inference.
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}
3145
3146// Build a SubViewOp with static entries and inferred result type.
3147void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3148 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3149 ArrayRef<int64_t> strides,
3150 ArrayRef<NamedAttribute> attrs) {
3151 SmallVector<OpFoldResult> offsetValues =
3152 llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
3153 return b.getI64IntegerAttr(v);
3154 });
3155 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
3156 sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
3157 SmallVector<OpFoldResult> strideValues =
3158 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
3159 return b.getI64IntegerAttr(v);
3160 });
3161 build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
3162}
3163
3164// Build a SubViewOp with dynamic entries and custom result type. If the
3165// type passed is nullptr, it is inferred.
3166void SubViewOp::build(OpBuilder &b, OperationState &result,
3167 MemRefType resultType, Value source,
3168 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
3169 ArrayRef<int64_t> strides,
3170 ArrayRef<NamedAttribute> attrs) {
3171 SmallVector<OpFoldResult> offsetValues =
3172 llvm::map_to_vector<4>(offsets, [&](int64_t v) -> OpFoldResult {
3173 return b.getI64IntegerAttr(v);
3174 });
3175 SmallVector<OpFoldResult> sizeValues = llvm::map_to_vector<4>(
3176 sizes, [&](int64_t v) -> OpFoldResult { return b.getI64IntegerAttr(v); });
3177 SmallVector<OpFoldResult> strideValues =
3178 llvm::map_to_vector<4>(strides, [&](int64_t v) -> OpFoldResult {
3179 return b.getI64IntegerAttr(v);
3180 });
3181 build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
3182 attrs);
3183}
3184
3185// Build a SubViewOp with dynamic entries and custom result type. If the type
3186// passed is nullptr, it is inferred.
3187void SubViewOp::build(OpBuilder &b, OperationState &result,
3188 MemRefType resultType, Value source, ValueRange offsets,
3189 ValueRange sizes, ValueRange strides,
3190 ArrayRef<NamedAttribute> attrs) {
3191 SmallVector<OpFoldResult> offsetValues = llvm::map_to_vector<4>(
3192 offsets, [](Value v) -> OpFoldResult { return v; });
3193 SmallVector<OpFoldResult> sizeValues =
3194 llvm::map_to_vector<4>(sizes, [](Value v) -> OpFoldResult { return v; });
3195 SmallVector<OpFoldResult> strideValues = llvm::map_to_vector<4>(
3196 strides, [](Value v) -> OpFoldResult { return v; });
3197 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
3198}
3199
3200// Build a SubViewOp with dynamic entries and inferred result type.
// Build a SubViewOp with dynamic entries and inferred result type.
void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
                      ValueRange offsets, ValueRange sizes, ValueRange strides,
                      ArrayRef<NamedAttribute> attrs) {
  // Passing a null MemRefType requests result-type inference.
  build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
}
3206
/// For ViewLikeOpInterface.
/// The viewed operand of a subview is its `source` memref.
Value SubViewOp::getViewSource() { return getSource(); }
3209
3210/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
3211/// static value).
3212static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
3213 int64_t t1Offset, t2Offset;
3214 SmallVector<int64_t> t1Strides, t2Strides;
3215 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3216 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3217 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
3218}
3219
3220/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
3221/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
3222/// marked as dropped in `droppedDims`.
3223static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
3224 const llvm::SmallBitVector &droppedDims) {
3225 assert(size_t(t1.getRank()) == droppedDims.size() &&
3226 "incorrect number of bits");
3227 assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
3228 "incorrect number of dropped dims");
3229 int64_t t1Offset, t2Offset;
3230 SmallVector<int64_t> t1Strides, t2Strides;
3231 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3232 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3233 if (failed(res1) || failed(res2))
3234 return false;
3235 for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
3236 if (droppedDims[i])
3237 continue;
3238 if (t1Strides[i] != t2Strides[j])
3239 return false;
3240 ++j;
3241 }
3242 return true;
3243}
3244
3246 SubViewOp op, Type expectedType) {
3247 auto memrefType = llvm::cast<ShapedType>(expectedType);
3248 switch (result) {
3250 return success();
3252 return op->emitError("expected result rank to be smaller or equal to ")
3253 << "the source rank, but got " << op.getType();
3255 return op->emitError("expected result type to be ")
3256 << expectedType
3257 << " or a rank-reduced version. (mismatch of result sizes), but got "
3258 << op.getType();
3260 return op->emitError("expected result element type to be ")
3261 << memrefType.getElementType() << ", but got " << op.getType();
3263 return op->emitError(
3264 "expected result and source memory spaces to match, but got ")
3265 << op.getType();
3267 return op->emitError("expected result type to be ")
3268 << expectedType
3269 << " or a rank-reduced version. (mismatch of result layout), but "
3270 "got "
3271 << op.getType();
3272 }
3273 llvm_unreachable("unexpected subview verification result");
3274}
3275
3276/// Verifier for SubViewOp.
3277LogicalResult SubViewOp::verify() {
3278 MemRefType baseType = getSourceType();
3279 MemRefType subViewType = getType();
3280 ArrayRef<int64_t> staticOffsets = getStaticOffsets();
3281 ArrayRef<int64_t> staticSizes = getStaticSizes();
3282 ArrayRef<int64_t> staticStrides = getStaticStrides();
3283
3284 // The base memref and the view memref should be in the same memory space.
3285 if (baseType.getMemorySpace() != subViewType.getMemorySpace())
3286 return emitError("different memory spaces specified for base memref "
3287 "type ")
3288 << baseType << " and subview memref type " << subViewType;
3289
3290 // Verify that the base memref type has a strided layout map.
3291 if (!baseType.isStrided())
3292 return emitError("base type ") << baseType << " is not strided";
3293
3294 // Compute the expected result type, assuming that there are no rank
3295 // reductions.
3296 MemRefType expectedType = SubViewOp::inferResultType(
3297 baseType, staticOffsets, staticSizes, staticStrides);
3298
3299 // Verify all properties of a shaped type: rank, element type and dimension
3300 // sizes. This takes into account potential rank reductions.
3301 auto shapedTypeVerification = isRankReducedType(
3302 /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
3303 if (shapedTypeVerification != SliceVerificationResult::Success)
3304 return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
3305
3306 // Make sure that the memory space did not change.
3307 if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3309 *this, expectedType);
3310
3311 // Verify the offset of the layout map.
3312 if (!haveCompatibleOffsets(expectedType, subViewType))
3314 *this, expectedType);
3315
3316 // The only thing that's left to verify now are the strides. First, compute
3317 // the unused dimensions due to rank reductions. We have to look at sizes and
3318 // strides to decide which dimensions were dropped. This function also
3319 // partially verifies strides in case of rank reductions.
3320 auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3321 getMixedSizes());
3322 if (failed(unusedDims))
3324 *this, expectedType);
3325
3326 // Strides must match.
3327 if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3329 *this, expectedType);
3330
3331 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3332 // to the base memref.
3333 SliceBoundsVerificationResult boundsResult =
3334 verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
3335 staticStrides, /*generateErrorMessage=*/true);
3336 if (!boundsResult.isValid)
3337 return getOperation()->emitError(boundsResult.errorMessage);
3338
3339 return success();
3340}
3341
3343 return os << "range " << range.offset << ":" << range.size << ":"
3344 << range.stride;
3345}
3346
3347/// Return the list of Range (i.e. offset, size, stride). Each Range
3348/// entry contains either the dynamic value or a ConstantIndexOp constructed
3349/// with `b` at location `loc`.
3350SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3351 OpBuilder &b, Location loc) {
3352 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3353 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3354 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3356 unsigned rank = ranks[0];
3357 res.reserve(rank);
3358 for (unsigned idx = 0; idx < rank; ++idx) {
3359 Value offset =
3360 op.isDynamicOffset(idx)
3361 ? op.getDynamicOffset(idx)
3362 : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
3363 Value size =
3364 op.isDynamicSize(idx)
3365 ? op.getDynamicSize(idx)
3366 : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
3367 Value stride =
3368 op.isDynamicStride(idx)
3369 ? op.getDynamicStride(idx)
3370 : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
3371 res.emplace_back(Range{offset, size, stride});
3372 }
3373 return res;
3374}
3375
3376/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3377/// to deduce the result type for the given `sourceType`. Additionally, reduce
3378/// the rank of the inferred result type if `currentResultType` is lower rank
3379/// than `currentSourceType`. Use this signature if `sourceType` is updated
3380/// together with the result type. In this case, it is important to compute
3381/// the dropped dimensions using `currentSourceType` whose strides align with
3382/// `currentResultType`.
3384 MemRefType currentResultType, MemRefType currentSourceType,
3385 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3386 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3387 MemRefType nonRankReducedType = SubViewOp::inferResultType(
3388 sourceType, mixedOffsets, mixedSizes, mixedStrides);
3389 FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3390 currentSourceType, currentResultType, mixedSizes);
3391 if (failed(unusedDims))
3392 return nullptr;
3393
3394 auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3395 SmallVector<int64_t> shape, strides;
3396 unsigned numDimsAfterReduction =
3397 nonRankReducedType.getRank() - unusedDims->count();
3398 shape.reserve(numDimsAfterReduction);
3399 strides.reserve(numDimsAfterReduction);
3400 for (const auto &[idx, size, stride] :
3401 llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3402 nonRankReducedType.getShape(), layout.getStrides())) {
3403 if (unusedDims->test(idx))
3404 continue;
3405 shape.push_back(size);
3406 strides.push_back(stride);
3407 }
3408
3409 return MemRefType::get(shape, nonRankReducedType.getElementType(),
3410 StridedLayoutAttr::get(sourceType.getContext(),
3411 layout.getOffset(), strides),
3412 nonRankReducedType.getMemorySpace());
3413}
3414
3416 OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3417 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3418 unsigned rank = memrefType.getRank();
3419 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3421 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3422 MemRefType targetType = SubViewOp::inferRankReducedResultType(
3423 targetShape, memrefType, offsets, sizes, strides);
3424 return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3425 sizes, strides);
3426}
3427
3428FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3429 Value value,
3430 ArrayRef<int64_t> desiredShape) {
3431 auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3432 assert(sourceMemrefType && "not a ranked memref type");
3433 auto sourceShape = sourceMemrefType.getShape();
3434 if (sourceShape.equals(desiredShape))
3435 return value;
3436 auto maybeRankReductionMask =
3437 mlir::computeRankReductionMask(sourceShape, desiredShape);
3438 if (!maybeRankReductionMask)
3439 return failure();
3440 return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3441}
3442
3443/// Helper method to check if a `subview` operation is trivially a no-op. This
3444/// is the case if the all offsets are zero, all strides are 1, and the source
3445/// shape is same as the size of the subview. In such cases, the subview can
3446/// be folded into its source.
3447static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3448 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3449 return false;
3450
3451 auto mixedOffsets = subViewOp.getMixedOffsets();
3452 auto mixedSizes = subViewOp.getMixedSizes();
3453 auto mixedStrides = subViewOp.getMixedStrides();
3454
3455 // Check offsets are zero.
3456 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3457 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3458 return !intValue || intValue.value() != 0;
3459 }))
3460 return false;
3461
3462 // Check strides are one.
3463 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3464 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3465 return !intValue || intValue.value() != 1;
3466 }))
3467 return false;
3468
3469 // Check all size values are static and matches the (static) source shape.
3470 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3471 for (const auto &size : llvm::enumerate(mixedSizes)) {
3472 std::optional<int64_t> intValue = getConstantIntValue(size.value());
3473 if (!intValue || *intValue != sourceShape[size.index()])
3474 return false;
3475 }
3476 // All conditions met. The `SubViewOp` is foldable as a no-op.
3477 return true;
3478}
3479
namespace {
/// Pattern to rewrite a subview op with MemRefCast arguments.
/// This essentially pushes memref.cast past its consuming subview when
/// `canFoldIntoConsumerOp` is true.
///
/// Example:
/// ```
/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
///   memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
/// ```
/// is rewritten into:
/// ```
/// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
/// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
///   memref<3x4xf32, strided<[?, 1], offset: ?>>
/// ```
class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    // Any constant operand, just return to let SubViewOpConstantFolder kick
    // in.
    if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Only fire when the source is produced by a memref.cast ...
    auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
    if (!castOp)
      return failure();

    // ... that is known to be safely absorbable by a consumer.
    if (!CastOp::canFoldIntoConsumerOp(castOp))
      return failure();

    // Compute the SubViewOp result type after folding the MemRefCastOp. Use
    // the MemRefCastOp source operand type to infer the result type and the
    // current SubViewOp source operand type to compute the dropped dimensions
    // if the operation is rank-reducing.
    auto resultType = getCanonicalSubViewResultType(
        subViewOp.getType(), subViewOp.getSourceType(),
        llvm::cast<MemRefType>(castOp.getSource().getType()),
        subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
        subViewOp.getMixedStrides());
    if (!resultType)
      return failure();

    // Rebuild the subview on the cast's source, then cast the result back to
    // the original (possibly less static) result type.
    Value newSubView = SubViewOp::create(
        rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
        subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
        subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
        subViewOp.getStaticStrides());
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        newSubView);
    return success();
  }
};

/// Canonicalize subview ops that are no-ops. When the source shape is not
/// same as a result shape due to use of `affine_map`.
class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
public:
  using OpRewritePattern<SubViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(SubViewOp subViewOp,
                                PatternRewriter &rewriter) const override {
    if (!isTrivialSubViewOp(subViewOp))
      return failure();
    // Types match exactly: replace with the source directly.
    if (subViewOp.getSourceType() == subViewOp.getType()) {
      rewriter.replaceOp(subViewOp, subViewOp.getSource());
      return success();
    }
    // Same shape but different layout representation: keep a cast.
    rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
                                        subViewOp.getSource());
    return success();
  }
};
} // namespace
3560
3561/// Return the canonical type of the result of a subview.
3563 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3564 ArrayRef<OpFoldResult> mixedSizes,
3565 ArrayRef<OpFoldResult> mixedStrides) {
3566 // Infer a memref type without taking into account any rank reductions.
3567 MemRefType resTy = SubViewOp::inferResultType(
3568 op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3569 if (!resTy)
3570 return {};
3571 MemRefType nonReducedType = resTy;
3572
3573 // Directly return the non-rank reduced type if there are no dropped dims.
3574 llvm::SmallBitVector droppedDims = op.getDroppedDims();
3575 if (droppedDims.none())
3576 return nonReducedType;
3577
3578 // Take the strides and offset from the non-rank reduced type.
3579 auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3580
3581 // Drop dims from shape and strides.
3582 SmallVector<int64_t> targetShape;
3583 SmallVector<int64_t> targetStrides;
3584 for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3585 if (droppedDims.test(i))
3586 continue;
3587 targetStrides.push_back(nonReducedStrides[i]);
3588 targetShape.push_back(nonReducedType.getDimSize(i));
3589 }
3590
3591 return MemRefType::get(targetShape, nonReducedType.getElementType(),
3592 StridedLayoutAttr::get(nonReducedType.getContext(),
3593 offset, targetStrides),
3594 nonReducedType.getMemorySpace());
3595 }
3596};
3597
3598/// A canonicalizer wrapper to replace SubViewOps.
3600 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3601 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3602 }
3603};
3604
// Register subview canonicalizations: constant-argument folding (with the
// canonical-type/replacement helpers above), cast absorption, and removal of
// trivial whole-source subviews.
void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
               SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
           SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
}
3612
3613OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3614 MemRefType sourceMemrefType = getSource().getType();
3615 MemRefType resultMemrefType = getResult().getType();
3616 auto resultLayout =
3617 dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3618
3619 if (resultMemrefType == sourceMemrefType &&
3620 resultMemrefType.hasStaticShape() &&
3621 (!resultLayout || resultLayout.hasStaticLayout())) {
3622 return getViewSource();
3623 }
3624
3625 // Fold subview(subview(x)), where both subviews have the same size and the
3626 // second subview's offsets are all zero. (I.e., the second subview is a
3627 // no-op.)
3628 if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3629 auto srcSizes = srcSubview.getMixedSizes();
3630 auto sizes = getMixedSizes();
3631 auto offsets = getMixedOffsets();
3632 bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3633 auto strides = getMixedStrides();
3634 bool allStridesOne = llvm::all_of(strides, isOneInteger);
3635 bool allSizesSame = llvm::equal(sizes, srcSizes);
3636 if (allOffsetsZero && allStridesOne && allSizesSame &&
3637 resultMemrefType == sourceMemrefType)
3638 return getViewSource();
3639 }
3640
3641 return {};
3642}
3643
// Bubble memory-space casts on the source operand below this op, using the
// shared pass-through implementation.
FailureOr<std::optional<SmallVector<Value>>>
SubViewOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
3648
3649void SubViewOp::inferStridedMetadataRanges(
3650 ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
3651 SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
3652 auto isUninitialized =
3653 +[](IntegerValueRange range) { return range.isUninitialized(); };
3654
3655 // Bail early if any of the operands metadata is not ready:
3656 SmallVector<IntegerValueRange> offsetOperands =
3657 getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
3658 if (llvm::any_of(offsetOperands, isUninitialized))
3659 return;
3660
3661 SmallVector<IntegerValueRange> sizeOperands =
3662 getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
3663 if (llvm::any_of(sizeOperands, isUninitialized))
3664 return;
3665
3666 SmallVector<IntegerValueRange> stridesOperands =
3667 getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
3668 if (llvm::any_of(stridesOperands, isUninitialized))
3669 return;
3670
3671 StridedMetadataRange sourceRange =
3672 ranges[getSourceMutable().getOperandNumber()];
3673 if (sourceRange.isUninitialized())
3674 return;
3675
3676 ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
3677
3678 // Get the dropped dims.
3679 llvm::SmallBitVector droppedDims = getDroppedDims();
3680
3681 // Compute the new offset, strides and sizes.
3682 ConstantIntRanges offset = sourceRange.getOffsets()[0];
3683 SmallVector<ConstantIntRanges> strides, sizes;
3684
3685 for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
3686 bool dropped = droppedDims.test(i);
3687 // Compute the new offset.
3688 ConstantIntRanges off =
3689 intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
3690 offset = intrange::inferAdd({offset, off});
3691
3692 // Skip dropped dimensions.
3693 if (dropped)
3694 continue;
3695 // Multiply the strides.
3696 strides.push_back(
3697 intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
3698 // Get the sizes.
3699 sizes.push_back(sizeOperands[i].getValue());
3700 }
3701
3702 setMetadata(getResult(),
3704 SmallVector<ConstantIntRanges>({std::move(offset)}),
3705 std::move(sizes), std::move(strides)));
3706}
3707
3708//===----------------------------------------------------------------------===//
3709// TransposeOp
3710//===----------------------------------------------------------------------===//
3711
// Suggest "transpose" as the printed SSA name prefix for the result.
void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "transpose");
}
3716
3717/// Build a strided memref type by applying `permutationMap` to `memRefType`.
3718static MemRefType inferTransposeResultType(MemRefType memRefType,
3719 AffineMap permutationMap) {
3720 auto originalSizes = memRefType.getShape();
3721 auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3722 assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3723
3724 // Compute permuted sizes and strides.
3725 auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3726 auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3727
3728 return MemRefType::Builder(memRefType)
3729 .setShape(sizes)
3730 .setLayout(
3731 StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3732}
3733
3734void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3735 AffineMapAttr permutation,
3736 ArrayRef<NamedAttribute> attrs) {
3737 auto permutationMap = permutation.getValue();
3738 assert(permutationMap);
3739
3740 auto memRefType = llvm::cast<MemRefType>(in.getType());
3741 // Compute result type.
3742 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3743
3744 result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3745 build(b, result, resultType, in, attrs);
3746}
3747
// transpose $in $permutation attr-dict : type($in) `to` type(results)
void TransposeOp::print(OpAsmPrinter &p) {
  // The permutation is printed inline (not in the attr-dict), so elide it.
  p << " " << getIn() << " " << getPermutation();
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
  p << " : " << getIn().getType() << " to " << getType();
}
3754
3755ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3756 OpAsmParser::UnresolvedOperand in;
3757 AffineMap permutation;
3758 MemRefType srcType, dstType;
3759 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3760 parser.parseOptionalAttrDict(result.attributes) ||
3761 parser.parseColonType(srcType) ||
3762 parser.resolveOperand(in, srcType, result.operands) ||
3763 parser.parseKeywordType("to", dstType) ||
3764 parser.addTypeToList(dstType, result.types))
3765 return failure();
3766
3767 result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3768 AffineMapAttr::get(permutation));
3769 return success();
3770}
3771
3772LogicalResult TransposeOp::verify() {
3773 if (!getPermutation().isPermutation())
3774 return emitOpError("expected a permutation map");
3775 if (getPermutation().getNumDims() != getIn().getType().getRank())
3776 return emitOpError("expected a permutation map of same rank as the input");
3777
3778 auto srcType = llvm::cast<MemRefType>(getIn().getType());
3779 auto resultType = llvm::cast<MemRefType>(getType());
3780 auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3781 .canonicalizeStridedLayout();
3782
3783 if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3784 return emitOpError("result type ")
3785 << resultType
3786 << " is not equivalent to the canonical transposed input type "
3787 << canonicalResultType;
3788 return success();
3789}
3790
3791OpFoldResult TransposeOp::fold(FoldAdaptor) {
3792 // First check for identity permutation, we can fold it away if input and
3793 // result types are identical already.
3794 if (getPermutation().isIdentity() && getType() == getIn().getType())
3795 return getIn();
3796 // Fold two consecutive memref.transpose Ops into one by composing their
3797 // permutation maps.
3798 if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3799 AffineMap composedPermutation =
3800 getPermutation().compose(otherTransposeOp.getPermutation());
3801 getInMutable().assign(otherTransposeOp.getIn());
3802 setPermutation(composedPermutation);
3803 return getResult();
3804 }
3805 return {};
3806}
3807
// Bubble memory-space casts on the input operand below this op, using the
// shared pass-through implementation.
FailureOr<std::optional<SmallVector<Value>>>
TransposeOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
}
3812
3813//===----------------------------------------------------------------------===//
3814// ViewOp
3815//===----------------------------------------------------------------------===//
3816
// Suggest "view" as the printed SSA name prefix for the result.
void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "view");
}
3820
3821LogicalResult ViewOp::verify() {
3822 auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3823 auto viewType = getType();
3824
3825 // The base memref should have identity layout map (or none).
3826 if (!baseType.getLayout().isIdentity())
3827 return emitError("unsupported map for base memref type ") << baseType;
3828
3829 // The result memref should have identity layout map (or none).
3830 if (!viewType.getLayout().isIdentity())
3831 return emitError("unsupported map for result memref type ") << viewType;
3832
3833 // The base memref and the view memref should be in the same memory space.
3834 if (baseType.getMemorySpace() != viewType.getMemorySpace())
3835 return emitError("different memory spaces specified for base memref "
3836 "type ")
3837 << baseType << " and view memref type " << viewType;
3838
3839 // Verify that we have the correct number of sizes for the result type.
3840 if (failed(verifyDynamicDimensionCount(getOperation(), viewType, getSizes())))
3841 return failure();
3842
3843 return success();
3844}
3845
// For ViewLikeOpInterface: a view aliases its source memref.
Value ViewOp::getViewSource() { return getSource(); }
3847
3848OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3849 MemRefType sourceMemrefType = getSource().getType();
3850 MemRefType resultMemrefType = getResult().getType();
3851
3852 if (resultMemrefType == sourceMemrefType &&
3853 resultMemrefType.hasStaticShape() && isZeroInteger(getByteShift()))
3854 return getViewSource();
3855
3856 return {};
3857}
3858
3859SmallVector<OpFoldResult> ViewOp::getMixedSizes() {
3860 SmallVector<OpFoldResult> result;
3861 unsigned ctr = 0;
3862 Builder b(getContext());
3863 for (int64_t dim : getType().getShape()) {
3864 if (ShapedType::isDynamic(dim)) {
3865 result.push_back(getSizes()[ctr++]);
3866 } else {
3867 result.push_back(b.getIndexAttr(dim));
3868 }
3869 }
3870 return result;
3871}
3872
namespace {
/// Given a memref type and a range of values that defines its dynamic
/// dimension sizes, turn all dynamic sizes that have a constant value into
/// static dimension sizes.
static MemRefType
foldDynamicToStaticDimSizes(MemRefType type, ValueRange dynamicSizes,
                            SmallVectorImpl<Value> &foldedDynamicSizes) {
  SmallVector<int64_t> staticShape(type.getShape());
  assert(type.getNumDynamicDims() == dynamicSizes.size() &&
         "incorrect number of dynamic sizes");

  // Compute new static and dynamic sizes.
  unsigned ctr = 0;
  for (auto [dim, dimSize] : llvm::enumerate(type.getShape())) {
    if (ShapedType::isStatic(dimSize))
      continue;

    Value dynamicSize = dynamicSizes[ctr++];
    if (auto cst = getConstantIntValue(dynamicSize)) {
      // Dynamic size must be non-negative.
      if (cst.value() < 0) {
        // Keep negative constants dynamic (verification will catch misuse).
        foldedDynamicSizes.push_back(dynamicSize);
        continue;
      }
      staticShape[dim] = cst.value();
    } else {
      // Not a constant: stays a dynamic size operand.
      foldedDynamicSizes.push_back(dynamicSize);
    }
  }

  return MemRefType::Builder(type).setShape(staticShape);
}

/// Change the result type of a `memref.view` by making originally dynamic
/// dimensions static when their sizes come from `constant` ops.
/// Example:
/// ```
/// %c5 = arith.constant 5: index
/// %0 = memref.view %src[%offset][%c5] : memref<?xi8> to memref<?x4xf32>
/// ```
/// to
/// ```
/// %0 = memref.view %src[%offset][] : memref<?xi8> to memref<5x4xf32>
/// ```
struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<Value> foldedDynamicSizes;
    MemRefType resultType = viewOp.getType();
    MemRefType foldedMemRefType = foldDynamicToStaticDimSizes(
        resultType, viewOp.getSizes(), foldedDynamicSizes);

    // Stop here if no dynamic size was promoted to static.
    if (foldedMemRefType == resultType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), foldedMemRefType,
                                    viewOp.getSource(), viewOp.getByteShift(),
                                    foldedDynamicSizes);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, resultType, newViewOp);
    return success();
  }
};

/// view(memref.cast(%source)) -> view(%source).
struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using Base::Base;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    auto memrefCastOp = viewOp.getSource().getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();

    // The view reinterprets the underlying buffer, so the cast can be
    // bypassed without changing the view's result type.
    rewriter.replaceOpWithNewOp<ViewOp>(
        viewOp, viewOp.getType(), memrefCastOp.getSource(),
        viewOp.getByteShift(), viewOp.getSizes());
    return success();
  }
};
} // namespace
3958
// Register view canonicalizations: constant-size promotion and cast bypass.
void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}
3963
// Bubble memory-space casts on the source operand below this op, using the
// shared pass-through implementation.
FailureOr<std::optional<SmallVector<Value>>>
ViewOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}
3968
3969//===----------------------------------------------------------------------===//
3970// AtomicRMWOp
3971//===----------------------------------------------------------------------===//
3972
3973LogicalResult AtomicRMWOp::verify() {
3974 if (getMemRefType().getRank() != getNumOperands() - 2)
3975 return emitOpError(
3976 "expects the number of subscripts to be equal to memref rank");
3977 switch (getKind()) {
3978 case arith::AtomicRMWKind::addf:
3979 case arith::AtomicRMWKind::maximumf:
3980 case arith::AtomicRMWKind::minimumf:
3981 case arith::AtomicRMWKind::mulf:
3982 if (!llvm::isa<FloatType>(getValue().getType()))
3983 return emitOpError() << "with kind '"
3984 << arith::stringifyAtomicRMWKind(getKind())
3985 << "' expects a floating-point type";
3986 break;
3987 case arith::AtomicRMWKind::addi:
3988 case arith::AtomicRMWKind::maxs:
3989 case arith::AtomicRMWKind::maxu:
3990 case arith::AtomicRMWKind::mins:
3991 case arith::AtomicRMWKind::minu:
3992 case arith::AtomicRMWKind::muli:
3993 case arith::AtomicRMWKind::ori:
3994 case arith::AtomicRMWKind::xori:
3995 case arith::AtomicRMWKind::andi:
3996 if (!llvm::isa<IntegerType>(getValue().getType()))
3997 return emitOpError() << "with kind '"
3998 << arith::stringifyAtomicRMWKind(getKind())
3999 << "' expects an integer type";
4000 break;
4001 default:
4002 break;
4003 }
4004 return success();
4005}
4006
OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
  /// atomicrmw(memrefcast) -> atomicrmw
  // Absorb a producing memref.cast into the memref operand; the op itself is
  // otherwise unchanged (the fold reports in-place modification).
  if (succeeded(foldMemRefCast(*this, getValue())))
    return getResult();
  return OpFoldResult();
}
4013
4014FailureOr<std::optional<SmallVector<Value>>>
4015AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
4017 getResult());
4018}
4019
4020//===----------------------------------------------------------------------===//
4021// TableGen'd op method definitions
4022//===----------------------------------------------------------------------===//
4023
4024#define GET_OP_CLASSES
4025#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool hasSideEffects(Operation *op)
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition IRAffine.cpp:59
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
auto load
static LogicalResult foldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
Definition MemRefOps.cpp:98
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMaskByStrides(MemRefType originalType, MemRefType reducedType, ArrayRef< int64_t > originalStrides, ArrayRef< int64_t > candidateStrides, llvm::SmallBitVector unusedDims)
Returns the set of source dimensions that are dropped in a rank reduction.
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMaskByPosition(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Returns the set of source dimensions that are dropped in a rank reduction.
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static RankedTensorType foldDynamicToStaticDimSizes(RankedTensorType type, ValueRange dynamicSizes, SmallVector< Value > &foldedDynamicSizes)
Given a ranked tensor type and a range of values that defines its dynamic dimension sizes,...
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:255
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
IndexType getIndexType()
Definition Builders.cpp:55
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:566
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:400
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:528
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:414
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:778
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:234
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:255
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:412
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:706
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:251
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Region.h:98
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StridedMetadataRange getRanked(SmallVectorImpl< ConstantIntRanges > &&offsets, SmallVectorImpl< ConstantIntRanges > &&sizes, SmallVectorImpl< ConstantIntRanges > &&strides)
Returns a ranked strided metadata range.
ArrayRef< ConstantIntRanges > getStrides() const
Get the strides ranges.
bool isUninitialized() const
Returns whether the metadata is uninitialized.
ArrayRef< ConstantIntRanges > getOffsets() const
Get the offsets range.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIndex() const
Definition Types.cpp:56
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition MemRefOps.cpp:62
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition MemRefOps.cpp:70
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:79
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
bool isZeroInteger(OpFoldResult v)
Return "true" if v is an integer value/attribute with constant value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:26
LogicalResult verifyElementTypesMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching element types.
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn
Callback function type for setting the strided metadata of a value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:144
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
function_ref< IntegerValueRange(Value)> GetIntRangeFn
Helper callback type to get the integer range of a value.
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(ExpandShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.