//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Matchers.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}
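
// Illustrative sketch (hedged; not part of the upstream source): with this
// helper, a consumer's fold hook can absorb a ranked `memref.cast` feeding one
// of its operands. For example, DeallocOp::fold below uses it, so hypothetical
// IR such as
//
//   %0 = memref.cast %arg : memref<8x16xf32> to memref<?x?xf32>
//   memref.dealloc %0 : memref<?x?xf32>
//
// folds to
//
//   memref.dealloc %arg : memref<8x16xf32>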

/// Return an unranked/ranked tensor type for the given unranked/ranked memref
/// type.
Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
  if (auto memref = llvm::dyn_cast<MemRefType>(type))
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
                                  int64_t dim) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  if (memrefType.isDynamicDim(dim))
    return builder.createOrFold<memref::DimOp>(loc, value, dim);

  return builder.getIndexAttr(memrefType.getDimSize(dim));
}

SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
                                                Location loc, Value value) {
  auto memrefType = llvm::cast<MemRefType>(value.getType());
  SmallVector<OpFoldResult> result;
  for (int64_t i = 0; i < memrefType.getRank(); ++i)
    result.push_back(getMixedSize(builder, loc, value, i));
  return result;
}
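
// A hedged usage sketch (not from the original source): for a value `%v` of
// type memref<4x?xf32>, the helper yields one static and one dynamic entry:
//
//   SmallVector<OpFoldResult> sizes = memref::getMixedSizes(b, loc, v);
//   // sizes[0] == b.getIndexAttr(4); sizes[1] is the Value produced by a
//   // materialized `memref.dim %v, %c1`.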

//===----------------------------------------------------------------------===//
// Utility functions for propagating static information
//===----------------------------------------------------------------------===//

/// Helper function that sets values[i] to constValues[i] if the latter is a
/// static value, i.e., not equal to ShapedType::kDynamic.
///
/// If constValues[i] is dynamic, tries to extract a constant value from
/// values[i] to allow for additional folding opportunities. Also converts all
/// existing attributes to index attributes. (They may be i64 attributes.)
static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
                                ArrayRef<int64_t> constValues) {
  assert(constValues.size() == values.size() &&
         "incorrect number of const values");
  for (auto [i, cstVal] : llvm::enumerate(constValues)) {
    Builder builder(values[i].getContext());
    if (ShapedType::isStatic(cstVal)) {
      // Constant value is known, use it directly.
      values[i] = builder.getIndexAttr(cstVal);
      continue;
    }
    if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
      // Try to extract a constant or convert an existing one to index.
      values[i] = builder.getIndexAttr(*cst);
    }
  }
}
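
// Illustrative sketch (hedged; not part of the upstream source): given
// constValues = {8, ShapedType::kDynamic} and values = {%c8 (an
// arith.constant 8 : index), an i64 attribute holding 5}, the first entry
// becomes getIndexAttr(8) because the static shape wins, and the second
// becomes getIndexAttr(5) after the constant is extracted and converted to an
// index attribute.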

/// Helper function to retrieve a lossless memory-space cast, and the
/// corresponding new result memref type.
static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
getMemorySpaceCastInfo(PtrLikeTypeInterface resultTy, Value src) {
  MemorySpaceCastOpInterface castOp =
      MemorySpaceCastOpInterface::getIfPromotableCast(src);

  // Bail if the cast is not lossless.
  if (!castOp)
    return {};

  // Transform the source and target type of `castOp` to have the same metadata
  // as `resultTy`. Bail if not possible.
  FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
      castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
  if (failed(srcTy))
    return {};

  FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
      castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
  if (failed(tgtTy))
    return {};

  // Check if this is a valid memory-space cast.
  if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
    return {};

  return std::make_tuple(castOp, *tgtTy, *srcTy);
}

/// Implementation of the `bubbleDownCasts` method for memref operations that
/// return a single memref result.
template <typename ConcreteOpTy>
static FailureOr<std::optional<SmallVector<Value>>>
bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
                                 OpOperand &src) {
  auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
  // Bail if we cannot cast.
  if (!castOp)
    return failure();

  // Create the new operands.
  SmallVector<Value> operands;
  llvm::append_range(operands, op->getOperands());
  operands[src.getOperandNumber()] = castOp.getSourcePtr();

  // Create the new op and results.
  auto newOp = ConcreteOpTy::create(
      builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
      llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));

  // Insert a memory-space cast to the original memory space of the op.
  MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
      builder, tgtTy,
      cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
  return std::optional<SmallVector<Value>>(
      SmallVector<Value>({result.getTargetPtr()}));
}
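
// Illustrative sketch (hedged; not part of the upstream source): bubbling a
// lossless memory-space cast below a passthrough op, assuming the cast
// qualifies under MemorySpaceCastOpInterface::getIfPromotableCast.
// Hypothetical IR:
//
//   %0 = memref.memory_space_cast %src : memref<?xf32, 1> to memref<?xf32>
//   %1 = memref.assume_alignment %0, 64 : memref<?xf32>
//
// becomes:
//
//   %2 = memref.assume_alignment %src, 64 : memref<?xf32, 1>
//   %1 = memref.memory_space_cast %2 : memref<?xf32, 1> to memref<?xf32>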

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

void AllocOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloc");
}

void AllocaOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "alloca");
}

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getLayout().isIdentity())
    numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
  if (op.getSymbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.getSymbolOperands().size();

  return success();
}

LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }

LogicalResult AllocaOp::verify() {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(*this);
}

namespace {
/// Fold constant dimensions into an alloc-like operation.
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
          APInt constSizeArg;
          if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
            return false;
          return constSizeArg.isNonNegative();
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (ShapedType::isStatic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
      APInt constSizeArg;
      if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
          constSizeArg.isNonNegative()) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constSizeArg.getZExtValue());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old
        // memref.
        newShapeConstants.push_back(ShapedType::kDynamic);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create the new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
                                        dynamicSizes, alloc.getSymbolOperands(),
                                        alloc.getAlignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
    return success();
  }
};
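
// Illustrative sketch (hedged; not part of the upstream source) of the rewrite
// above. Hypothetical IR:
//
//   %c16 = arith.constant 16 : index
//   %0 = memref.alloc(%c16, %n) : memref<?x?xf32>
//
// canonicalizes to:
//
//   %1 = memref.alloc(%n) : memref<16x?xf32>
//   %0 = memref.cast %1 : memref<16x?xf32> to memref<?x?xf32>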

/// Fold alloc operations with no users or only store and dealloc uses.
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.getValue() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
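
// Illustrative sketch (hedged; not part of the upstream source): an alloc
// whose only uses are stores into it and its dealloc is removed entirely.
// Given hypothetical IR
//
//   %0 = memref.alloc() : memref<4xf32>
//   memref.store %v, %0[%i] : memref<4xf32>
//   memref.dealloc %0 : memref<4xf32>
//
// all three ops are erased.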
} // namespace

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// ReallocOp
//===----------------------------------------------------------------------===//

LogicalResult ReallocOp::verify() {
  auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
  MemRefType resultType = getType();

  // The source memref should have identity layout (or none).
  if (!sourceType.getLayout().isIdentity())
    return emitError("unsupported layout for source memref type ")
           << sourceType;

  // The result memref should have identity layout (or none).
  if (!resultType.getLayout().isIdentity())
    return emitError("unsupported layout for result memref type ")
           << resultType;

  // The source memref and the result memref should be in the same memory
  // space.
  if (sourceType.getMemorySpace() != resultType.getMemorySpace())
    return emitError("different memory spaces specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // The source memref and the result memref should have the same element type.
  if (sourceType.getElementType() != resultType.getElementType())
    return emitError("different element types specified for source memref "
                     "type ")
           << sourceType << " and result memref type " << resultType;

  // Verify that we have the dynamic dimension operand when it is needed.
  if (resultType.getNumDynamicDims() && !getDynamicResultSize())
    return emitError("missing dimension operand for result type ")
           << resultType;
  if (!resultType.getNumDynamicDims() && getDynamicResultSize())
    return emitError("unnecessary dimension operand for result type ")
           << resultType;

  return success();
}
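
// Illustrative sketch (hedged; not part of the upstream source): a statically
// shaped realloc result needs no size operand, while a dynamic result requires
// exactly one:
//
//   %0 = memref.realloc %src : memref<64xf32> to memref<124xf32>
//   %1 = memref.realloc %src(%d) : memref<64xf32> to memref<?xf32>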

void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<SimplifyDeadAlloc<ReallocOp>>(context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

void AllocaScopeOp::print(OpAsmPrinter &p) {
  bool printBlockTerminators = false;

  p << ' ';
  if (!getResults().empty()) {
    p << " -> (" << getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p << ' ';
  p.printRegion(getBodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict((*this)->getAttrs());
}

ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
  if (!point.isParent()) {
    regions.push_back(RegionSuccessor(getOperation(), getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&getBodyRegion()));
}

/// Given an operation, return whether this op is guaranteed to
/// allocate an AutomaticAllocationScopeResource.
static bool isGuaranteedAutomaticAllocation(Operation *op) {
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return false;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Given an operation, return whether this op itself could
/// allocate an AutomaticAllocationScopeResource. Note that
/// this will not check whether an operation contained within
/// the op can allocate.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
  // This op itself doesn't create a stack allocation;
  // the inner allocation should be handled separately.
  if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
    return false;
  MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!interface)
    return true;
  for (auto res : op->getResults()) {
    if (auto effect =
            interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
      if (isa<SideEffects::AutomaticAllocationScopeResource>(
              effect->getResource()))
        return true;
    }
  }
  return false;
}

/// Return whether this op is the last non-terminating op
/// in a region. That is to say, it is in a one-block region
/// and is only followed by a terminator. This prevents
/// extending the lifetime of allocations.
static bool lastNonTerminatorInRegion(Operation *op) {
  return op->getBlock()->mightHaveTerminator() &&
         op->getNextNode() == op->getBlock()->getTerminator() &&
         op->getBlock()->getParent()->hasOneBlock();
}

/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
/// or it contains no allocation.
struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {
    bool hasPotentialAlloca =
        op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
            if (alloc == op)
              return WalkResult::advance();
            if (isOpItselfPotentialAutomaticAllocation(alloc))
              return WalkResult::interrupt();
            if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
              return WalkResult::skip();
            return WalkResult::advance();
          }).wasInterrupted();

    // If this contains no potential allocation, it is always legal to
    // inline. Otherwise, consider two conditions:
    if (hasPotentialAlloca) {
      // If the parent isn't an allocation scope, or we are not the last
      // non-terminator op in the parent, we will extend the lifetime.
      if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
        return failure();
      if (!lastNonTerminatorInRegion(op))
        return failure();
    }

    Block *block = &op.getRegion().front();
    Operation *terminator = block->getTerminator();
    ValueRange results = terminator->getOperands();
    rewriter.inlineBlockBefore(block, op);
    rewriter.replaceOp(op, results);
    rewriter.eraseOp(terminator);
    return success();
  }
};
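
// Illustrative sketch (hedged; not part of the upstream source): a scope that
// contains no allocation is inlined into its parent. Hypothetical IR:
//
//   memref.alloca_scope {
//     "test.use"(%x) : (f32) -> ()
//   }
//
// becomes simply:
//
//   "test.use"(%x) : (f32) -> ()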

/// Move allocations into an allocation scope, if it is legal to
/// move them (e.g. their operands are available at the location
/// the op would be moved to).
struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
  using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocaScopeOp op,
                                PatternRewriter &rewriter) const override {

    if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    Operation *lastParentWithoutScope = op->getParentOp();

    if (!lastParentWithoutScope ||
        lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
      return failure();

    // Only apply if this is the last non-terminator op in the block of a
    // one-block region (lest the lifetime be extended).
    if (!lastNonTerminatorInRegion(op) ||
        !lastNonTerminatorInRegion(lastParentWithoutScope))
      return failure();

    while (!lastParentWithoutScope->getParentOp()
                ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
      lastParentWithoutScope = lastParentWithoutScope->getParentOp();
      if (!lastParentWithoutScope ||
          !lastNonTerminatorInRegion(lastParentWithoutScope))
        return failure();
    }
    assert(lastParentWithoutScope->getParentOp()
               ->hasTrait<OpTrait::AutomaticAllocationScope>());

    Region *containingRegion = nullptr;
    for (auto &r : lastParentWithoutScope->getRegions()) {
      if (r.isAncestor(op->getParentRegion())) {
        assert(containingRegion == nullptr &&
               "only one region can contain the op");
        containingRegion = &r;
      }
    }
    assert(containingRegion && "op must be contained in a region");

    SmallVector<Operation *> toHoist;
    op->walk([&](Operation *alloc) {
      if (!isGuaranteedAutomaticAllocation(alloc))
        return WalkResult::skip();

      // If any operand is not defined before the location of
      // lastParentWithoutScope (i.e. where we would hoist to), skip.
      if (llvm::any_of(alloc->getOperands(), [&](Value v) {
            return containingRegion->isAncestor(v.getParentRegion());
          }))
        return WalkResult::skip();
      toHoist.push_back(alloc);
      return WalkResult::advance();
    });

    if (toHoist.empty())
      return failure();
    rewriter.setInsertionPoint(lastParentWithoutScope);
    for (auto *op : toHoist) {
      auto *cloned = rewriter.clone(*op);
      rewriter.replaceOp(op, cloned->getResults());
    }
    return success();
  }
};
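
// Illustrative sketch (hedged; not part of the upstream source): an alloca
// guaranteed to be an automatic allocation may be hoisted out of nested ops
// into the surrounding allocation scope, provided the last-op-in-region
// conditions above hold. Hypothetical IR:
//
//   func.func @f(%c: i1) {          // AutomaticAllocationScope
//     scf.if %c {
//       memref.alloca_scope {
//         %a = memref.alloca() : memref<4xf32>
//         "test.use"(%a) : (memref<4xf32>) -> ()
//       }
//     }
//     return
//   }
//
// may clone the alloca directly before the scf.if, into the function's own
// allocation scope.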

void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                MLIRContext *context) {
  results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(getAlignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}

void AssumeAlignmentOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "assume_align");
}

OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
  auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
  if (!source)
    return {};
  if (source.getAlignment() != getAlignment())
    return {};
  return getMemref();
}
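
// Illustrative sketch (hedged; not part of the upstream source): chained
// assume_alignment ops with the same alignment fold to the incoming value.
// Hypothetical IR:
//
//   %0 = memref.assume_alignment %src, 64 : memref<?xf32>
//   %1 = memref.assume_alignment %0, 64 : memref<?xf32>
//
// %1 folds to %0.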

FailureOr<std::optional<SmallVector<Value>>>
AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
}

//===----------------------------------------------------------------------===//
// DistinctObjectsOp
//===----------------------------------------------------------------------===//

LogicalResult DistinctObjectsOp::verify() {
  if (getOperandTypes() != getResultTypes())
    return emitOpError("operand types and result types must match");

  if (getOperandTypes().empty())
    return emitOpError("expected at least one operand");

  return success();
}

LogicalResult DistinctObjectsOp::inferReturnTypes(
    MLIRContext * /*context*/, std::optional<Location> /*location*/,
    ValueRange operands, DictionaryAttr /*attributes*/,
    OpaqueProperties /*properties*/, RegionRange /*regions*/,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
  return success();
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "cast");
}

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType =
      llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
  MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
      failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (ShapedType::isDynamic(sourceOffset) &&
        ShapedType::isStatic(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = llvm::dyn_cast<MemRefType>(a);
  auto bT = llvm::dyn_cast<MemRefType>(b);

  auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
  auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getLayout() != bT.getLayout()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
          failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (const auto &aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
          aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    return aMemSpace == bMemSpace;
  }

  return false;
}
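
// Illustrative sketches (hedged; not part of the upstream source) of cast
// compatibility under the rules above:
//
//   memref<4x?xf32> to memref<?x4xf32>   // ok: each dim static or dynamic
//   memref<4xf32>   to memref<*xf32>     // ok: ranked <-> unranked
//   memref<*xf32>   to memref<*xf32>     // rejected: unranked to unranked
//   memref<4xf32>   to memref<4xi32>     // rejected: element type differs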

OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

FailureOr<std::optional<SmallVector<Value>>>
CastOp::bubbleDownCasts(OpBuilder &builder) {
  return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
}

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold memref.copy(%x, %x).
struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getSource() != copyOp.getTarget())
      return failure();

    rewriter.eraseOp(copyOp);
    return success();
  }
};

struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;

  static bool isEmptyMemRef(BaseMemRefType type) {
    return type.hasRank() && llvm::is_contained(type.getShape(), 0);
  }

  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (isEmptyMemRef(copyOp.getSource().getType()) ||
        isEmptyMemRef(copyOp.getTarget().getType())) {
      rewriter.eraseOp(copyOp);
      return success();
    }

    return failure();
  }
};
} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<FoldEmptyCopy, FoldSelfCopy>(context);
}

/// If the source/target of a CopyOp is a CastOp that does not modify the shape
/// and element type, the cast can be skipped. Such CastOps only cast the layout
/// of the type.
static LogicalResult FoldCopyOfCast(CopyOp op) {
  for (OpOperand &operand : op->getOpOperands()) {
    auto castOp = operand.get().getDefiningOp<memref::CastOp>();
    if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
      operand.set(castOp.getOperand());
      return success();
    }
  }
  return failure();
}
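
// Illustrative sketch (hedged; not part of the upstream source): a layout-only
// cast on the source of a copy is skipped. Hypothetical IR:
//
//   %0 = memref.cast %src : memref<8xf32> to memref<?xf32>
//   memref.copy %0, %dst : memref<?xf32> to memref<?xf32>
//
// folds to:
//
//   memref.copy %src, %dst : memref<8xf32> to memref<?xf32>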

LogicalResult CopyOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {

  /// copy(memrefcast) -> copy
  return FoldCopyOfCast(*this);
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "dim");
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
  build(builder, result, source, indexValue);
}

std::optional<int64_t> DimOp::getConstantIndex() {
  return getConstantIntValue(getIndex());
}

Speculation::Speculatability DimOp::getSpeculatability() {
  auto constantIndex = getConstantIndex();
  if (!constantIndex)
    return Speculation::NotSpeculatable;

  auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
  if (!rankedSourceType)
    return Speculation::NotSpeculatable;

  if (rankedSourceType.getRank() <= constantIndex)
    return Speculation::NotSpeculatable;

  return Speculation::Speculatable;
}

void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
                                          SetIntLatticeFn setResultRange) {
  setResultRange(getResult(),
                 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
}

/// Return a map whose keys are the elements of `vals` and whose values are the
/// number of occurrences of each element. Use std::map, since the `vals` here
/// are strides and the dynamic stride value is the same as the tombstone value
/// for `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the `originalType` and a `reducedType` whose shape is assumed to be
/// a subset of `originalType` with some `1` entries erased, return the set of
/// indices that specifies which of the entries of the original shape are
/// dropped to obtain the reduced shape.
/// This accounts for cases where there are multiple unit-dims, but only a
/// subset of those are dropped. For MemRefTypes these can be disambiguated
/// using the strides. If a dimension is dropped the stride must be dropped
/// too.
static FailureOr<llvm::SmallBitVector>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayRef<OpFoldResult> sizes) {
  llvm::SmallBitVector unusedDims(originalType.getRank());
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (const auto &dim : llvm::enumerate(sizes))
    if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
      if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
        unusedDims.set(dim.index());

  // Early exit for the case where the number of unused dims matches the number
  // of ranks reduced.
  if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
      originalType.getRank())
    return unusedDims;

  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
      failed(
          reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
    return failure();

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim that stride should not be
  // present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to exactly figure out
  // which dim corresponds to which stride; we just need to verify that the
  // number of repetitions of a stride in the original + number of unused dims
  // with that stride == number of repetitions of that stride in the candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
    if (!unusedDims.test(dim))
      continue;
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this is not dropped. Keep as is.
      unusedDims.reset(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced-rank type
      // that wasn't in the original one.
      return failure();
    }
  }

  if ((int64_t)unusedDims.count() + reducedType.getRank() !=
      originalType.getRank())
    return failure();
  return unusedDims;
}
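
// Illustrative sketch (hedged; not part of the upstream source): for
// originalType = memref<1x1x8xf32, strided<[8, 8, 1]>> reduced to
// reducedType = memref<1x8xf32, strided<[8, 1]>> with sizes [1, 1, 8], both
// unit dims are candidates, but only one stride-8 entry disappeared from the
// candidate strides, so exactly one of the two unit dims is reported as
// dropped.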

llvm::SmallBitVector SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  FailureOr<llvm::SmallBitVector> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
  assert(succeeded(unusedDims) && "unable to find unused dims of subview");
  return *unusedDims;
}

OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
  // All forms of folding require a known index.
  auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
  if (!memrefType)
    return {};

  // Out of bound indices produce undefined behavior but are still valid IR.
  // Don't choke on them.
  int64_t indexVal = index.getInt();
  if (indexVal < 0 || indexVal >= memrefType.getRank())
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = getSource().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallBitVector unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.test(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}
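
// Illustrative sketch (hedged; not part of the upstream source): `dim` folds
// either to a constant for static dims or to the corresponding dynamic-size
// operand of the defining op:
//
//   %0 = memref.alloc(%n) : memref<4x?xf32>
//   %d0 = memref.dim %0, %c0 : memref<4x?xf32>   // folds to constant 4
//   %d1 = memref.dim %0, %c1 : memref<4x?xf32>   // folds to %n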

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return rewriter.notifyMatchFailure(
          dim, "Dim op is not defined by a reshape op.");

    // dim of a memref reshape can be folded if dim.getIndex() dominates the
    // reshape. Instead of using `DominanceInfo` (which is usually costly) we
    // cheaply check that either of the following conditions hold:
    //   1. dim.getIndex() is defined in the same block as reshape but before
    //      reshape.
    //   2. dim.getIndex() is defined in a parent block of reshape.

    // Check condition 1
    if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
      if (auto *definingOp = dim.getIndex().getDefiningOp()) {
        if (reshape->isBeforeInBlock(definingOp)) {
          return rewriter.notifyMatchFailure(
              dim,
              "dim.getIndex is not defined before reshape in the same block.");
        }
      } // else dim.getIndex is a block argument to reshape->getBlock and
        // dominates reshape
    } // Check condition 2
    else if (dim->getBlock() != reshape->getBlock() &&
             !dim.getIndex().getParentRegion()->isProperAncestor(
                 reshape->getParentRegion())) {
      // If dim and reshape are in the same block but dim.getIndex() isn't, we
      // already know dim.getIndex() dominates reshape without calling
      // `isProperAncestor`
      return rewriter.notifyMatchFailure(
          dim, "dim.getIndex does not dominate reshape.");
    }

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load =
        LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
    if (load.getType() != dim.getType())
      load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};
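
// Illustrative sketch (hedged; not part of the upstream source) of the pattern
// above. Hypothetical IR:
//
//   %r = memref.reshape %src(%shape)
//       : (memref<?x?xf32>, memref<2xindex>) -> memref<?x?xf32>
//   %d = memref.dim %r, %i : memref<?x?xf32>
//
// rewrites to a load from the shape operand:
//
//   %d = memref.load %shape[%i] : memref<2xindex>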

} // namespace

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

void DmaStartOp::print(OpAsmPrinter &p) {
  p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
    << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
    << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
  if (isStrided())
    p << ", " << getStride() << ", " << getNumElementsPerStride();

  p.printOptionalAttrDict((*this)->getAttrs());
  p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
    << ", " << getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::UnresolvedOperand srcMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
  OpAsmParser::UnresolvedOperand dstMemRefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
  OpAsmParser::UnresolvedOperand numElementsInfo;
  OpAsmParser::UnresolvedOperand tagMemrefInfo;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
  SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size in KiB.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

LogicalResult DmaStartOp::verify() {
  unsigned numOperands = getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
    return emitOpError("expected source to be of memref type");
  if (numOperands < getSrcMemRefRank() + 4)
    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
                         << " operands";
  if (!getSrcIndices().empty() &&
      !llvm::all_of(getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
    return emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getDstIndices().empty() &&
      !llvm::all_of(getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!getNumElements().getType().isIndex())
    return emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
    return emitOpError("expected tag to be of memref type");
  numExpectedOperands += getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return emitOpError() << "expected at least " << numExpectedOperands
                         << " operands";
  if (!getTagIndices().empty() &&
      !llvm::all_of(getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return emitOpError("incorrect number of operands");

  // 5. Strides.
  if (isStrided()) {
    if (!getStride().getType().isIndex() ||
        !getNumElementsPerStride().getType().isIndex())
      return emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

LogicalResult DmaWaitOp::verify() {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = getTagIndices().size();
  unsigned tagMemRefRank = getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return emitOpError() << "expected tagIndices to have the same number of "
                            "elements as the tagMemRef rank, expected "
                         << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}

//===----------------------------------------------------------------------===//
// ExtractAlignedPointerAsIndexOp
//===----------------------------------------------------------------------===//

void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getResult(), "intptr");
}

//===----------------------------------------------------------------------===//
// ExtractStridedMetadataOp
//===----------------------------------------------------------------------===//

/// The number and type of the results are inferred from the
/// shape of the source.
LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
    MLIRContext *context, std::optional<Location> location,
    ExtractStridedMetadataOp::Adaptor adaptor,
    SmallVectorImpl<Type> &inferredReturnTypes) {
  auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
  if (!sourceType)
    return failure();

  unsigned sourceRank = sourceType.getRank();
  IndexType indexType = IndexType::get(context);
  auto memrefType =
      MemRefType::get({}, sourceType.getElementType(),
                      MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
  // Base.
  inferredReturnTypes.push_back(memrefType);
  // Offset.
  inferredReturnTypes.push_back(indexType);
  // Sizes and strides.
  for (unsigned i = 0; i < sourceRank * 2; ++i)
    inferredReturnTypes.push_back(indexType);
  return success();
}

void ExtractStridedMetadataOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  setNameFn(getBaseBuffer(), "base_buffer");
  setNameFn(getOffset(), "offset");
  // For multi-result to work properly with pretty names and packed syntax
  // `x:3` we can only give a pretty name to the first value in the pack.
  if (!getSizes().empty()) {
    setNameFn(getSizes().front(), "sizes");
    setNameFn(getStrides().front(), "strides");
  }
}

/// Helper function to perform the replacement of all constant uses of `values`
/// by a materialized constant extracted from `maybeConstants`.
/// `values` and `maybeConstants` are expected to have the same size.
template <typename Container>
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
                                  Container values,
                                  ArrayRef<OpFoldResult> maybeConstants) {
  assert(values.size() == maybeConstants.size() &&
         " expected values and maybeConstants of the same size");
  bool atLeastOneReplacement = false;
  for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
    // Don't materialize a constant if there are no uses: this would induce
    // infinite loops in the driver.
    if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
      continue;
    assert(isa<Attribute>(maybeConstant) &&
           "The constified value should be either unchanged (i.e., == result) "
           "or a constant");
    Value constantVal = arith::ConstantIndexOp::create(
        rewriter, loc,
        llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
    for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
      // modifyOpInPlace: lambda cannot capture structured bindings in C++17
      // yet.
      op->replaceUsesOfWith(result, constantVal);
      atLeastOneReplacement = true;
    }
  }
  return atLeastOneReplacement;
}

LogicalResult
ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
                               SmallVectorImpl<OpFoldResult> &results) {
  OpBuilder builder(*this);

  bool atLeastOneReplacement = replaceConstantUsesOf(
      builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
      getConstifiedMixedOffset());
  atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
                                                 getConstifiedMixedSizes());
  atLeastOneReplacement |= replaceConstantUsesOf(
      builder, getLoc(), getStrides(), getConstifiedMixedStrides());

  // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
  if (auto prev = getSource().getDefiningOp<CastOp>())
    if (isa<MemRefType>(prev.getSource().getType())) {
      getSourceMutable().assign(prev.getSource());
      atLeastOneReplacement = true;
    }

  return success(atLeastOneReplacement);
}

SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
  constifyIndexValues(values, getSource().getType().getShape());
  return values;
}

SmallVector<OpFoldResult>
ExtractStridedMetadataOp::getConstifiedMixedStrides() {
  SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
  SmallVector<int64_t> staticValues;
  int64_t unused;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(staticValues, unused);
  (void)status;
  assert(succeeded(status) && "could not get strides from type");
  constifyIndexValues(values, staticValues);
  return values;
}

OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
  OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
  SmallVector<OpFoldResult> values(1, offsetOfr);
  SmallVector<int64_t> staticValues, unused;
  int64_t offset;
  LogicalResult status =
      getSource().getType().getStridesAndOffset(unused, offset);
  (void)status;
  assert(succeeded(status) && "could not get offset from type");
  staticValues.push_back(offset);
  constifyIndexValues(values, staticValues);
  return values[0];
}

//===----------------------------------------------------------------------===//
// GenericAtomicRMWOp
//===----------------------------------------------------------------------===//

void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
                               Value memref, ValueRange ivs) {
  OpBuilder::InsertionGuard g(builder);
  result.addOperands(memref);
  result.addOperands(ivs);

  if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
    Type elementType = memrefType.getElementType();
    result.addTypes(elementType);

    Region *bodyRegion = result.addRegion();
    builder.createBlock(bodyRegion);
    bodyRegion->addArgument(elementType, memref.getLoc());
  }
}

LogicalResult GenericAtomicRMWOp::verify() {
  auto &body = getRegion();
  if (body.getNumArguments() != 1)
    return emitOpError("expected single number of entry block arguments");

  if (getResult().getType() != body.getArgument(0).getType())
    return emitOpError("expected block argument of the same type result type");

  bool hasSideEffects =
      body.walk([&](Operation *nestedOp) {
            if (isMemoryEffectFree(nestedOp))
              return WalkResult::advance();
            nestedOp->emitError(
                "body of 'memref.generic_atomic_rmw' should contain "
                "only operations with no side effects");
            return WalkResult::interrupt();
          })
          .wasInterrupted();
  return hasSideEffects ? failure() : success();
}

ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
                                      OperationState &result) {
  OpAsmParser::UnresolvedOperand memref;
  Type memrefType;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;

  Type indexType = parser.getBuilder().getIndexType();
  if (parser.parseOperand(memref) ||
      parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
      parser.parseColonType(memrefType) ||
      parser.resolveOperand(memref, memrefType, result.operands) ||
      parser.resolveOperands(ivs, indexType, result.operands))
    return failure();

  Region *body = result.addRegion();
  if (parser.parseRegion(*body, {}) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
  return success();
}
1545
1546void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1547 p << ' ' << getMemref() << "[" << getIndices()
1548 << "] : " << getMemref().getType() << ' ';
1549 p.printRegion(getRegion());
1550 p.printOptionalAttrDict((*this)->getAttrs());
1551}
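
// As a reference for the printer/parser above, the custom assembly looks
// like (example adapted from the op documentation):
// ```
// %x = memref.generic_atomic_rmw %I[%i] : memref<10xf32> {
// ^bb0(%current_value : f32):
//   %c1 = arith.constant 1.0 : f32
//   %inc = arith.addf %c1, %current_value : f32
//   memref.atomic_yield %inc : f32
// }
// ```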
1552
1553//===----------------------------------------------------------------------===//
1554// AtomicYieldOp
1555//===----------------------------------------------------------------------===//
1556
1557LogicalResult AtomicYieldOp::verify() {
1558 Type parentType = (*this)->getParentOp()->getResultTypes().front();
1559 Type resultType = getResult().getType();
1560 if (parentType != resultType)
1561 return emitOpError() << "types mismatch between yield op: " << resultType
1562 << " and its parent: " << parentType;
1563 return success();
1564}
1565
1566//===----------------------------------------------------------------------===//
1567// GlobalOp
1568//===----------------------------------------------------------------------===//
1569
1570static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1571 TypeAttr type,
1572 Attribute initialValue) {
1573 p << type;
1574 if (!op.isExternal()) {
1575 p << " = ";
1576 if (op.isUninitialized())
1577 p << "uninitialized";
1578 else
1579 p.printAttributeWithoutType(initialValue);
1580 }
1581}
1582
1583static ParseResult
1584parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1585 Attribute &initialValue) {
1586 Type type;
1587 if (parser.parseType(type))
1588 return failure();
1589
1590 auto memrefType = llvm::dyn_cast<MemRefType>(type);
1591 if (!memrefType || !memrefType.hasStaticShape())
1592 return parser.emitError(parser.getNameLoc())
1593 << "type should be static shaped memref, but got " << type;
1594 typeAttr = TypeAttr::get(type);
1595
1596 if (parser.parseOptionalEqual())
1597 return success();
1598
1599 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1600 initialValue = UnitAttr::get(parser.getContext());
1601 return success();
1602 }
1603
1604 Type tensorType = getTensorTypeFromMemRefType(memrefType);
1605 if (parser.parseAttribute(initialValue, tensorType))
1606 return failure();
1607 if (!llvm::isa<ElementsAttr>(initialValue))
1608 return parser.emitError(parser.getNameLoc())
1609 << "initial value should be a unit or elements attribute";
1610 return success();
1611}
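
// The type-and-initializer syntax handled above appears in globals such as
// (illustrative):
// ```
// memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>
// memref.global @z : memref<3xf16> = uninitialized
// memref.global @ext : memref<2xf32>
// ```
// The last form is external: the `= <initial-value>` part is absent.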
1612
1613LogicalResult GlobalOp::verify() {
1614 auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1615 if (!memrefType || !memrefType.hasStaticShape())
1616 return emitOpError("type should be static shaped memref, but got ")
1617 << getType();
1618
1619 // Verify that the initial value, if present, is either a unit attribute or
1620 // an elements attribute.
1621 if (getInitialValue().has_value()) {
1622 Attribute initValue = getInitialValue().value();
1623 if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1624 return emitOpError("initial value should be a unit or elements "
1625 "attribute, but got ")
1626 << initValue;
1627
1628 // Check that the type of the initial value is compatible with the type of
1629 // the global variable.
1630 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1631 // Check the element types match.
1632 auto initElementType =
1633 cast<TensorType>(elementsAttr.getType()).getElementType();
1634 auto memrefElementType = memrefType.getElementType();
1635
1636 if (initElementType != memrefElementType)
1637 return emitOpError("initial value element expected to be of type ")
1638 << memrefElementType << ", but was of type " << initElementType;
1639
1640 // Check the shapes match, given that memref globals can only produce
1641 // statically shaped memrefs and elements literal type must have a static
1642 // shape we can assume both types are shaped.
1643 auto initShape = elementsAttr.getShapedType().getShape();
1644 auto memrefShape = memrefType.getShape();
1645 if (initShape != memrefShape)
1646 return emitOpError("initial value shape expected to be ")
1647 << memrefShape << " but was " << initShape;
1648 }
1649 }
1650
1651 // TODO: verify visibility for declarations.
1652 return success();
1653}
1654
1655ElementsAttr GlobalOp::getConstantInitValue() {
1656 auto initVal = getInitialValue();
1657 if (getConstant() && initVal.has_value())
1658 return llvm::cast<ElementsAttr>(initVal.value());
1659 return {};
1660}
1661
1662//===----------------------------------------------------------------------===//
1663// GetGlobalOp
1664//===----------------------------------------------------------------------===//
1665
1666LogicalResult
1667GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1668 // Verify that the result type is same as the type of the referenced
1669 // memref.global op.
1670 auto global =
1671 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1672 if (!global)
1673 return emitOpError("'")
1674 << getName() << "' does not reference a valid global memref";
1675
1676 Type resultType = getResult().getType();
1677 if (global.getType() != resultType)
1678 return emitOpError("result type ")
1679 << resultType << " does not match type " << global.getType()
1680 << " of the global memref @" << getName();
1681 return success();
1682}
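
// For example (illustrative), given:
// ```
// memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>
// %0 = memref.get_global @c : memref<2xf32>
// ```
// a result type other than memref<2xf32> would be rejected here.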
1683
1684//===----------------------------------------------------------------------===//
1685// LoadOp
1686//===----------------------------------------------------------------------===//
1687
1688LogicalResult LoadOp::verify() {
1689 if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1690 return emitOpError("incorrect number of indices for load, expected ")
1691 << getMemRefType().getRank() << " but got " << getIndices().size();
1692 }
1693 return success();
1694}
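
// E.g., a rank-2 source requires exactly two indices (illustrative):
// ```
// %v = memref.load %m[%i, %j] : memref<4x?xf32>
// ```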
1695
1696OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1697 /// load(memrefcast) -> load
1698 if (succeeded(foldMemRefCast(*this)))
1699 return getResult();
1700 return OpFoldResult();
1701}
1702
1703FailureOr<std::optional<SmallVector<Value>>>
1704LoadOp::bubbleDownCasts(OpBuilder &builder) {
1706 getResult());
1707}
1708
1709//===----------------------------------------------------------------------===//
1710// MemorySpaceCastOp
1711//===----------------------------------------------------------------------===//
1712
1713void MemorySpaceCastOp::getAsmResultNames(
1714 function_ref<void(Value, StringRef)> setNameFn) {
1715 setNameFn(getResult(), "memspacecast");
1716}
1717
1718bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1719 if (inputs.size() != 1 || outputs.size() != 1)
1720 return false;
1721 Type a = inputs.front(), b = outputs.front();
1722 auto aT = llvm::dyn_cast<MemRefType>(a);
1723 auto bT = llvm::dyn_cast<MemRefType>(b);
1724
1725 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1726 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1727
1728 if (aT && bT) {
1729 if (aT.getElementType() != bT.getElementType())
1730 return false;
1731 if (aT.getLayout() != bT.getLayout())
1732 return false;
1733 if (aT.getShape() != bT.getShape())
1734 return false;
1735 return true;
1736 }
1737 if (uaT && ubT) {
1738 return uaT.getElementType() == ubT.getElementType();
1739 }
1740 return false;
1741}
1742
1743OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1744 // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1745 // t2)
1746 if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1747 getSourceMutable().assign(parentCast.getSource());
1748 return getResult();
1749 }
1750 return Value{};
1751}
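
// The compatibility rules and fold above cover IR such as (illustrative):
// ```
// %1 = memref.memory_space_cast %0 : memref<8xf32, 1> to memref<8xf32>
// %2 = memref.memory_space_cast %1 : memref<8xf32> to memref<8xf32, 2>
// ```
// where the second cast folds into a single cast from %0 to memref<8xf32, 2>.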
1752
1753TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
1754 return getSource();
1755}
1756
1757TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
1758 return getDest();
1759}
1760
1761bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1762 PtrLikeTypeInterface src) {
1763 return isa<BaseMemRefType>(tgt) &&
1764 tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
1765}
1766
1767MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
1768 OpBuilder &b, PtrLikeTypeInterface tgt,
1769 TypedValue<PtrLikeTypeInterface> src) {
1770 assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
1771 return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
1772}
1773
1774/// The only cast we recognize as promotable is to the generic space.
1775bool MemorySpaceCastOp::isSourcePromotable() {
1776 return getDest().getType().getMemorySpace() == nullptr;
1777}
1778
1779//===----------------------------------------------------------------------===//
1780// PrefetchOp
1781//===----------------------------------------------------------------------===//
1782
1783void PrefetchOp::print(OpAsmPrinter &p) {
1784 p << " " << getMemref() << '[';
1785 p.printOperands(getIndices());
1786 p << ']' << ", " << (getIsWrite() ? "write" : "read");
1787 p << ", locality<" << getLocalityHint();
1788 p << ">, " << (getIsDataCache() ? "data" : "instr");
1789 p.printOptionalAttrDict(
1790 (*this)->getAttrs(),
1791 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1792 p << " : " << getMemRefType();
1793}
1794
1795ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1796 OpAsmParser::UnresolvedOperand memrefInfo;
1797 SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1798 IntegerAttr localityHint;
1799 MemRefType type;
1800 StringRef readOrWrite, cacheType;
1801
1802 auto indexTy = parser.getBuilder().getIndexType();
1803 auto i32Type = parser.getBuilder().getIntegerType(32);
1804 if (parser.parseOperand(memrefInfo) ||
1805 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1806 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1807 parser.parseComma() || parser.parseKeyword("locality") ||
1808 parser.parseLess() ||
1809 parser.parseAttribute(localityHint, i32Type, "localityHint",
1810 result.attributes) ||
1811 parser.parseGreater() || parser.parseComma() ||
1812 parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1813 parser.resolveOperand(memrefInfo, type, result.operands) ||
1814 parser.resolveOperands(indexInfo, indexTy, result.operands))
1815 return failure();
1816
1817 if (readOrWrite != "read" && readOrWrite != "write")
1818 return parser.emitError(parser.getNameLoc(),
1819 "rw specifier has to be 'read' or 'write'");
1820 result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1821 parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1822
1823 if (cacheType != "data" && cacheType != "instr")
1824 return parser.emitError(parser.getNameLoc(),
1825 "cache type has to be 'data' or 'instr'");
1826
1827 result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1828 parser.getBuilder().getBoolAttr(cacheType == "data"));
1829
1830 return success();
1831}
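
// The textual form round-tripped by this printer/parser is (example from the
// op documentation):
// ```
// memref.prefetch %0[%i], read, locality<3>, data : memref<400xf32>
// ```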
1832
1833LogicalResult PrefetchOp::verify() {
1834 if (getNumOperands() != 1 + getMemRefType().getRank())
1835 return emitOpError("too few indices");
1836
1837 return success();
1838}
1839
1840LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1841 SmallVectorImpl<OpFoldResult> &results) {
1842 // prefetch(memrefcast) -> prefetch
1843 return foldMemRefCast(*this);
1844}
1845
1846//===----------------------------------------------------------------------===//
1847// RankOp
1848//===----------------------------------------------------------------------===//
1849
1850OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1851 // Constant fold rank when the rank of the operand is known.
1852 auto type = getOperand().getType();
1853 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1854 if (shapedType && shapedType.hasRank())
1855 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1856 return IntegerAttr();
1857}
1858
1859//===----------------------------------------------------------------------===//
1860// ReinterpretCastOp
1861//===----------------------------------------------------------------------===//
1862
1863void ReinterpretCastOp::getAsmResultNames(
1864 function_ref<void(Value, StringRef)> setNameFn) {
1865 setNameFn(getResult(), "reinterpret_cast");
1866}
1867
1868/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1869/// `staticSizes` and `staticStrides` are automatically filled with
1870/// source-memref-rank sentinel values that encode dynamic entries.
1871void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1872 MemRefType resultType, Value source,
1873 OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1874 ArrayRef<OpFoldResult> strides,
1875 ArrayRef<NamedAttribute> attrs) {
1876 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1877 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1878 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1879 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1880 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1881 result.addAttributes(attrs);
1882 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1883 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1884 b.getDenseI64ArrayAttr(staticSizes),
1885 b.getDenseI64ArrayAttr(staticStrides));
1886}
1887
1888void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1889 Value source, OpFoldResult offset,
1890 ArrayRef<OpFoldResult> sizes,
1891 ArrayRef<OpFoldResult> strides,
1892 ArrayRef<NamedAttribute> attrs) {
1893 auto sourceType = cast<BaseMemRefType>(source.getType());
1894 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1895 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1896 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1897 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1898 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1899 auto stridedLayout = StridedLayoutAttr::get(
1900 b.getContext(), staticOffsets.front(), staticStrides);
1901 auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1902 stridedLayout, sourceType.getMemorySpace());
1903 build(b, result, resultType, source, offset, sizes, strides, attrs);
1904}
1905
1906void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1907 MemRefType resultType, Value source,
1908 int64_t offset, ArrayRef<int64_t> sizes,
1909 ArrayRef<int64_t> strides,
1910 ArrayRef<NamedAttribute> attrs) {
1911 SmallVector<OpFoldResult> sizeValues =
1912 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1913 return b.getI64IntegerAttr(v);
1914 }));
1915 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1916 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1917 return b.getI64IntegerAttr(v);
1918 }));
1919 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1920 strideValues, attrs);
1921}
1922
1923void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1924 MemRefType resultType, Value source, Value offset,
1925 ValueRange sizes, ValueRange strides,
1926 ArrayRef<NamedAttribute> attrs) {
1927 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1928 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1929 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1930 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1931 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1932}
1933
1934// TODO: ponder whether we want to allow missing trailing sizes/strides that are
1935// completed automatically, like we have for subview and extract_slice.
1936LogicalResult ReinterpretCastOp::verify() {
1937 // The source and result memrefs should be in the same memory space.
1938 auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1939 auto resultType = llvm::cast<MemRefType>(getType());
1940 if (srcType.getMemorySpace() != resultType.getMemorySpace())
1941 return emitError("different memory spaces specified for source type ")
1942 << srcType << " and result memref type " << resultType;
1943 if (srcType.getElementType() != resultType.getElementType())
1944 return emitError("different element types specified for source type ")
1945 << srcType << " and result memref type " << resultType;
1946
1947 // Match sizes in result memref type and in static_sizes attribute.
1948 for (auto [idx, resultSize, expectedSize] :
1949 llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1950 if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
1951 return emitError("expected result type with size = ")
1952 << (ShapedType::isDynamic(expectedSize)
1953 ? std::string("dynamic")
1954 : std::to_string(expectedSize))
1955 << " instead of " << resultSize << " in dim = " << idx;
1956 }
1957
1958 // Match offset and strides in static_offsets and static_strides attributes. If
1959 // result memref type has no affine map specified, this will assume an
1960 // identity layout.
1961 int64_t resultOffset;
1962 SmallVector<int64_t, 4> resultStrides;
1963 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1964 return emitError("expected result type to have strided layout but found ")
1965 << resultType;
1966
1967 // Match offset in result memref type and in static_offsets attribute.
1968 int64_t expectedOffset = getStaticOffsets().front();
1969 if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
1970 return emitError("expected result type with offset = ")
1971 << (ShapedType::isDynamic(expectedOffset)
1972 ? std::string("dynamic")
1973 : std::to_string(expectedOffset))
1974 << " instead of " << resultOffset;
1975
1976 // Match strides in result memref type and in static_strides attribute.
1977 for (auto [idx, resultStride, expectedStride] :
1978 llvm::enumerate(resultStrides, getStaticStrides())) {
1979 if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
1980 return emitError("expected result type with stride = ")
1981 << (ShapedType::isDynamic(expectedStride)
1982 ? std::string("dynamic")
1983 : std::to_string(expectedStride))
1984 << " instead of " << resultStride << " in dim = " << idx;
1985 }
1986
1987 return success();
1988}
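
// For instance (illustrative), the verifier accepts:
// ```
// %dst = memref.reinterpret_cast %src to
//     offset: [0], sizes: [10, %sz], strides: [%stride, 1]
//     : memref<?x?xf32> to memref<10x?xf32, strided<[?, 1]>>
// ```
// but rejects result types whose static offset, sizes, or strides disagree
// with the corresponding static_* attributes.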
1989
1990OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1991 Value src = getSource();
1992 auto getPrevSrc = [&]() -> Value {
1993 // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1994 if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1995 return prev.getSource();
1996
1997 // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1998 if (auto prev = src.getDefiningOp<CastOp>())
1999 return prev.getSource();
2000
2001 // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
2002 // are 0.
2003 if (auto prev = src.getDefiningOp<SubViewOp>())
2004 if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
2005 return prev.getSource();
2006
2007 return nullptr;
2008 };
2009
2010 if (auto prevSrc = getPrevSrc()) {
2011 getSourceMutable().assign(prevSrc);
2012 return getResult();
2013 }
2014
2015 // reinterpret_cast(x) w/o offset/shape/stride changes -> x
2016 if (ShapedType::isStaticShape(getType().getShape()) &&
2017 src.getType() == getType() && getStaticOffsets().front() == 0) {
2018 return src;
2019 }
2020
2021 return nullptr;
2022}
2023
2024SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
2025 SmallVector<OpFoldResult> values = getMixedSizes();
2026 constifyIndexValues(values, getType().getShape());
2027 return values;
2028}
2029
2030SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
2031 SmallVector<OpFoldResult> values = getMixedStrides();
2032 SmallVector<int64_t> staticValues;
2033 int64_t unused;
2034 LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
2035 (void)status;
2036 assert(succeeded(status) && "could not get strides from type");
2037 constifyIndexValues(values, staticValues);
2038 return values;
2039}
2040
2041OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
2042 SmallVector<OpFoldResult> values = getMixedOffsets();
2043 assert(values.size() == 1 &&
2044 "reinterpret_cast must have one and only one offset");
2045 SmallVector<int64_t> staticValues, unused;
2046 int64_t offset;
2047 LogicalResult status = getType().getStridesAndOffset(unused, offset);
2048 (void)status;
2049 assert(succeeded(status) && "could not get offset from type");
2050 staticValues.push_back(offset);
2051 constifyIndexValues(values, staticValues);
2052 return values[0];
2053}
2054
2055namespace {
2056/// Replace the sequence:
2057/// ```
2058/// base, offset, sizes, strides = extract_strided_metadata src
2059/// dst = reinterpret_cast base to offset, sizes, strides
2060/// ```
2061/// With
2062///
2063/// ```
2064/// dst = memref.cast src
2065/// ```
2066///
2067/// Note: The cast operation is only inserted when the type of dst and src
2068/// are not the same. E.g., when going from <4xf32> to <?xf32>.
2069///
2070/// This pattern also matches when the offset, sizes, and strides don't come
2071/// directly from the `extract_strided_metadata`'s results but it can be
2072/// statically proven that they would hold the same values.
2073///
2074/// For instance, the following sequence would be replaced:
2075/// ```
2076/// base, offset, sizes, strides =
2077/// extract_strided_metadata memref : memref<3x4xty>
2078/// dst = reinterpret_cast base to 0, [3, 4], strides
2079/// ```
2080/// Because we know (thanks to the type of the input memref) that variable
2081/// `offset` and `sizes` will respectively hold 0 and [3, 4].
2082///
2083/// Similarly, the following sequence would be replaced:
2084/// ```
2085/// c0 = arith.constant 0
2086/// c4 = arith.constant 4
2087/// base, offset, sizes, strides =
2088/// extract_strided_metadata memref : memref<3x4xty>
2089/// dst = reinterpret_cast base to c0, [3, c4], strides
2090/// ```
2091/// Because we know that `offset` and `c0` will hold 0
2092/// and `c4` will hold 4.
2093///
2094/// If the pattern above does not match, the input of the
2095/// extract_strided_metadata is always folded into the input of the
2096/// reinterpret_cast operator. This allows for dead code elimination to get rid
2097/// of the extract_strided_metadata in some cases.
2098struct ReinterpretCastOpExtractStridedMetadataFolder
2099 : public OpRewritePattern<ReinterpretCastOp> {
2100public:
2101 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2102
2103 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2104 PatternRewriter &rewriter) const override {
2105 auto extractStridedMetadata =
2106 op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2107 if (!extractStridedMetadata)
2108 return failure();
2109
2110 // Check if the reinterpret cast reconstructs a memref with the exact same
2111 // properties as the extract strided metadata.
2112 auto isReinterpretCastNoop = [&]() -> bool {
2113 // First, check that the strides are the same.
2114 if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
2115 op.getConstifiedMixedStrides()))
2116 return false;
2117
2118 // Second, check the sizes.
2119 if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
2120 op.getConstifiedMixedSizes()))
2121 return false;
2122
2123 // Finally, check the offset.
2124 assert(op.getMixedOffsets().size() == 1 &&
2125 "reinterpret_cast with more than one offset should have been "
2126 "rejected by the verifier");
2127 return extractStridedMetadata.getConstifiedMixedOffset() ==
2128 op.getConstifiedMixedOffset();
2129 };
2130
2131 if (!isReinterpretCastNoop()) {
2132 // If the extract_strided_metadata / reinterpret_cast pair can't be
2133 // completely folded, then we can still fold the input of the
2134 // extract_strided_metadata into the input of the reinterpret_cast.
2135 // In some cases (e.g., static dimensions), the
2136 // extract_strided_metadata is then eliminated by dead code elimination.
2137 //
2138 // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2139 //
2140 // We can always fold the input of a extract_strided_metadata operator
2141 // to the input of a reinterpret_cast operator, because they point to
2142 // the same memory. Note that the reinterpret_cast does not use the
2143 // layout of its input memref, only its base memory pointer which is
2144 // the same as the base pointer returned by the extract_strided_metadata
2145 // operator and the base pointer of the extract_strided_metadata memref
2146 // input.
2147 rewriter.modifyOpInPlace(op, [&]() {
2148 op.getSourceMutable().assign(extractStridedMetadata.getSource());
2149 });
2150 return success();
2151 }
2152
2153 // At this point, we know that the back and forth between extract strided
2154 // metadata and reinterpret cast is a noop. However, the final type of the
2155 // reinterpret cast may not be exactly the same as the original memref.
2156 // E.g., it could be changing a dimension from static to dynamic. Check that
2157 // here and add a cast if necessary.
2158 Type srcTy = extractStridedMetadata.getSource().getType();
2159 if (srcTy == op.getResult().getType())
2160 rewriter.replaceOp(op, extractStridedMetadata.getSource());
2161 else
2162 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2163 extractStridedMetadata.getSource());
2164
2165 return success();
2166 }
2167};
2168
2169struct ReinterpretCastOpConstantFolder
2170 : public OpRewritePattern<ReinterpretCastOp> {
2171public:
2172 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2173
2174 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2175 PatternRewriter &rewriter) const override {
2176 unsigned srcStaticCount = llvm::count_if(
2177 llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
2178 op.getMixedStrides()),
2179 [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
2180
2181 SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
2182 SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
2183 SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
2184
2185 // TODO: Using counting comparison instead of direct comparison because
2186 // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
2187 // IntegerAttrs, while constifyIndexValues (and therefore
2188 // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
2189 if (srcStaticCount ==
2190 llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
2191 [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
2192 return failure();
2193
2194 auto newReinterpretCast = ReinterpretCastOp::create(
2195 rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
2196
2197 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
2198 return success();
2199 }
2200};
2201} // namespace
2202
2203void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2204 MLIRContext *context) {
2205 results.add<ReinterpretCastOpExtractStridedMetadataFolder,
2206 ReinterpretCastOpConstantFolder>(context);
2207}
2208
2209FailureOr<std::optional<SmallVector<Value>>>
2210ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
2211 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2212}
2213
2214//===----------------------------------------------------------------------===//
2215// Reassociative reshape ops
2216//===----------------------------------------------------------------------===//
2217
2218void CollapseShapeOp::getAsmResultNames(
2219 function_ref<void(Value, StringRef)> setNameFn) {
2220 setNameFn(getResult(), "collapse_shape");
2221}
2222
2223void ExpandShapeOp::getAsmResultNames(
2224 function_ref<void(Value, StringRef)> setNameFn) {
2225 setNameFn(getResult(), "expand_shape");
2226}
2227
2228LogicalResult ExpandShapeOp::reifyResultShapes(
2229 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2230 reifiedResultShapes = {
2231 getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2232 return success();
2233}
2234
2235/// Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp
2236/// result and operand. Layout maps are verified separately.
2237///
2238/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2239/// allowed in a reassociation group.
2240static LogicalResult
2241verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2242 ArrayRef<int64_t> expandedShape,
2243 ArrayRef<ReassociationIndices> reassociation,
2244 bool allowMultipleDynamicDimsPerGroup) {
2245 // There must be one reassociation group per collapsed dimension.
2246 if (collapsedShape.size() != reassociation.size())
2247 return op->emitOpError("invalid number of reassociation groups: found ")
2248 << reassociation.size() << ", expected " << collapsedShape.size();
2249
2250 // The next expected expanded dimension index (while iterating over
2251 // reassociation indices).
2252 int64_t nextDim = 0;
2253 for (const auto &it : llvm::enumerate(reassociation)) {
2254 ReassociationIndices group = it.value();
2255 int64_t collapsedDim = it.index();
2256
2257 bool foundDynamic = false;
2258 for (int64_t expandedDim : group) {
2259 if (expandedDim != nextDim++)
2260 return op->emitOpError("reassociation indices must be contiguous");
2261
2262 if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2263 return op->emitOpError("reassociation index ")
2264 << expandedDim << " is out of bounds";
2265
2266 // Check if there are multiple dynamic dims in a reassociation group.
2267 if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2268 if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2269 return op->emitOpError(
2270 "at most one dimension in a reassociation group may be dynamic");
2271 foundDynamic = true;
2272 }
2273 }
2274
2275 // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2276 if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2277 return op->emitOpError("collapsed dim (")
2278 << collapsedDim
2279 << ") must be dynamic if and only if reassociation group is "
2280 "dynamic";
2281
2282 // If all dims in the reassociation group are static, the size of the
2283 // collapsed dim can be verified.
2284 if (!foundDynamic) {
2285 int64_t groupSize = 1;
2286 for (int64_t expandedDim : group)
2287 groupSize *= expandedShape[expandedDim];
2288 if (groupSize != collapsedShape[collapsedDim])
2289 return op->emitOpError("collapsed dim size (")
2290 << collapsedShape[collapsedDim]
2291 << ") must equal reassociation group size (" << groupSize << ")";
2292 }
2293 }
2294
2295 if (collapsedShape.empty()) {
2296 // Rank 0: All expanded dimensions must be 1.
2297 for (int64_t d : expandedShape)
2298 if (d != 1)
2299 return op->emitOpError(
2300 "rank 0 memrefs can only be extended/collapsed with/from ones");
2301 } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2302 // Rank >= 1: Number of dimensions among all reassociation groups must match
2303 // the result memref rank.
2304 return op->emitOpError("expanded rank (")
2305 << expandedShape.size()
2306 << ") inconsistent with number of reassociation indices (" << nextDim
2307 << ")";
2308 }
2309
2310 return success();
2311}
2312
2313SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2314 return getSymbolLessAffineMaps(getReassociationExprs());
2315}
2316
2317SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2318 return convertReassociationIndicesToExprs(getContext(),
2319 getReassociationIndices());
2320}
2321
2322SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2323 return getSymbolLessAffineMaps(getReassociationExprs());
2324}
2325
2326SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2327 return convertReassociationIndicesToExprs(getContext(),
2328 getReassociationIndices());
2329}
2330
2331/// Compute the layout map after expanding a given source MemRef type with the
2332/// specified reassociation indices.
2333static FailureOr<StridedLayoutAttr>
2334computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2335 ArrayRef<ReassociationIndices> reassociation) {
2336 int64_t srcOffset;
2337 SmallVector<int64_t> srcStrides;
2338 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2339 return failure();
2340 assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2341
2342 // 1-1 mapping between srcStrides and reassociation packs.
2343 // Each srcStride starts with the given value and gets expanded according to
2344 // the proper entries in resultShape.
2345 // Example:
2346 // srcStrides = [10000, 1 , 100 ],
2347 // reassociations = [ [0], [1], [2, 3, 4]],
2348 // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2349 // -> For the purpose of stride calculation, the useful sizes are:
2350 // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2351 // resultStrides = [10000, 1, 600, 200, 100]
2352 // Note that a stride does not get expanded along the first entry of each
2353 // shape pack.
2354 SmallVector<int64_t> reverseResultStrides;
2355 reverseResultStrides.reserve(resultShape.size());
2356 unsigned shapeIndex = resultShape.size() - 1;
2357 for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2358 ReassociationIndices reassoc = std::get<0>(it);
2359 int64_t currentStrideToExpand = std::get<1>(it);
2360 for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2361 reverseResultStrides.push_back(currentStrideToExpand);
2362 currentStrideToExpand =
2363 (SaturatedInteger::wrap(currentStrideToExpand) *
2364 SaturatedInteger::wrap(resultShape[shapeIndex--]))
2365 .asInteger();
2366 }
2367 }
2368 auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2369 resultStrides.resize(resultShape.size(), 1);
2370 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2371}
2372
2373FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2374 MemRefType srcType, ArrayRef<int64_t> resultShape,
2375 ArrayRef<ReassociationIndices> reassociation) {
2376 if (srcType.getLayout().isIdentity()) {
2377 // If the source is contiguous (i.e., no layout map specified), so is the
2378 // result.
2379 MemRefLayoutAttrInterface layout;
2380 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2381 srcType.getMemorySpace());
2382 }
2383
2384 // Source may not be contiguous. Compute the layout map.
2385 FailureOr<StridedLayoutAttr> computedLayout =
2386 computeExpandedLayoutMap(srcType, resultShape, reassociation);
2387 if (failed(computedLayout))
2388 return failure();
2389 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2390 srcType.getMemorySpace());
2391}
2392
2393FailureOr<SmallVector<OpFoldResult>>
2394ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2395 MemRefType expandedType,
2396 ArrayRef<ReassociationIndices> reassociation,
2397 ArrayRef<OpFoldResult> inputShape) {
2398 std::optional<SmallVector<OpFoldResult>> outputShape =
2399 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2400 inputShape);
2401 if (!outputShape)
2402 return failure();
2403 return *outputShape;
2404}
2405
2406void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2407 Type resultType, Value src,
2408 ArrayRef<ReassociationIndices> reassociation,
2409 ArrayRef<OpFoldResult> outputShape) {
2410 auto [staticOutputShape, dynamicOutputShape] =
2411 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2412 build(builder, result, llvm::cast<MemRefType>(resultType), src,
2413 getReassociationIndicesAttribute(builder, reassociation),
2414 dynamicOutputShape, staticOutputShape);
2415}
2416
2417void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2418 Type resultType, Value src,
2419 ArrayRef<ReassociationIndices> reassociation) {
2420 SmallVector<OpFoldResult> inputShape =
2421 getMixedSizes(builder, result.location, src);
2422 MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2423 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2424 builder, result.location, memrefResultTy, reassociation, inputShape);
2425 // Failure of this assertion usually indicates presence of multiple
2426 // dynamic dimensions in the same reassociation group.
2427 assert(succeeded(outputShape) && "unable to infer output shape");
2428 build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2429}
2430
2431void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2432 ArrayRef<int64_t> resultShape, Value src,
2433 ArrayRef<ReassociationIndices> reassociation) {
2434 // Only ranked memref source values are supported.
2435 auto srcType = llvm::cast<MemRefType>(src.getType());
2436 FailureOr<MemRefType> resultType =
2437 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2438 // Failure of this assertion usually indicates a problem with the source
2439 // type, e.g., could not get strides/offset.
2440 assert(succeeded(resultType) && "could not compute layout");
2441 build(builder, result, *resultType, src, reassociation);
2442}
2443
2444void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2445 ArrayRef<int64_t> resultShape, Value src,
2446 ArrayRef<ReassociationIndices> reassociation,
2447 ArrayRef<OpFoldResult> outputShape) {
2448 // Only ranked memref source values are supported.
2449 auto srcType = llvm::cast<MemRefType>(src.getType());
2450 FailureOr<MemRefType> resultType =
2451 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2452 // Failure of this assertion usually indicates a problem with the source
2453 // type, e.g., could not get strides/offset.
2454 assert(succeeded(resultType) && "could not compute layout");
2455 build(builder, result, *resultType, src, reassociation, outputShape);
2456}
2457
2458LogicalResult ExpandShapeOp::verify() {
2459 MemRefType srcType = getSrcType();
2460 MemRefType resultType = getResultType();
2461
2462 if (srcType.getRank() > resultType.getRank()) {
2463 auto r0 = srcType.getRank();
2464 auto r1 = resultType.getRank();
2465 return emitOpError("has source rank ")
2466 << r0 << " and result rank " << r1 << ". This is not an expansion ("
2467 << r0 << " > " << r1 << ").";
2468 }
2469
2470 // Verify result shape.
2471 if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2472 resultType.getShape(),
2473 getReassociationIndices(),
2474 /*allowMultipleDynamicDimsPerGroup=*/true)))
2475 return failure();
2476
2477 // Compute expected result type (including layout map).
2478 FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2479 srcType, resultType.getShape(), getReassociationIndices());
2480 if (failed(expectedResultType))
2481 return emitOpError("invalid source layout map");
2482
2483 // Check actual result type.
2484 if (*expectedResultType != resultType)
2485 return emitOpError("expected expanded type to be ")
2486 << *expectedResultType << " but found " << resultType;
2487
2488 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2489 return emitOpError("expected number of static shape bounds to be equal to "
2490 "the output rank (")
2491 << resultType.getRank() << ") but found "
2492 << getStaticOutputShape().size() << " inputs instead";
2493
2494 if ((int64_t)getOutputShape().size() !=
2495 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2496 return emitOpError("mismatch in dynamic dims in output_shape and "
2497 "static_output_shape: static_output_shape has ")
2498 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2499 << " dynamic dims while output_shape has " << getOutputShape().size()
2500 << " values";
2501
2502 // Verify if provided output shapes are in agreement with output type.
2503 DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2504 ArrayRef<int64_t> resShape = getResult().getType().getShape();
2505 for (auto [pos, shape] : llvm::enumerate(resShape)) {
2506 if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
2507 return emitOpError("invalid output shape provided at pos ") << pos;
2508 }
2509 }
2510
2511 return success();
2512}
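
// A valid expansion checked by this verifier (illustrative):
// ```
// %r = memref.expand_shape %m [[0, 1], [2]] output_shape [%d, 5, 4]
//     : memref<?x4xf32> into memref<?x5x4xf32>
// ```
// The reassociation groups are {0, 1} and {2}, and the single dynamic
// output_shape value %d matches the one dynamic result dimension.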
2513
2514void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2515 MLIRContext *context) {
2516 results.add<
2517 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2518 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2519}
2520
2521FailureOr<std::optional<SmallVector<Value>>>
2522ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
2523 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2524}
2525
2526/// Compute the layout map after collapsing a given source MemRef type with the
2527/// specified reassociation indices.
2528///
2529/// Note: All collapsed dims in a reassociation group must be contiguous. It is
2530/// not possible to check this by inspecting a MemRefType in the general case.
2531/// If non-contiguity cannot be checked statically, the collapse is assumed to
2532/// be valid (and thus accepted by this function) unless `strict = true`.
2533static FailureOr<StridedLayoutAttr>
2534computeCollapsedLayoutMap(MemRefType srcType,
2535 ArrayRef<ReassociationIndices> reassociation,
2536 bool strict = false) {
2537 int64_t srcOffset;
2538 SmallVector<int64_t> srcStrides;
2539 auto srcShape = srcType.getShape();
2540 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2541 return failure();
2542
2543 // The result stride of a reassociation group is the stride of the last entry
2544 // of the reassociation. (TODO: Should be the minimum stride in the
2545 // reassociation because strides are not necessarily sorted. E.g., when using
2546 // memref.transpose.) Dimensions of size 1 should be skipped, because their
2547 // strides are meaningless and could have any arbitrary value.
2548 SmallVector<int64_t> resultStrides;
2549 resultStrides.reserve(reassociation.size());
2550 for (const ReassociationIndices &reassoc : reassociation) {
2551 ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2552 while (srcShape[ref.back()] == 1 && ref.size() > 1)
2553 ref = ref.drop_back();
2554 if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2555 resultStrides.push_back(srcStrides[ref.back()]);
2556 } else {
2557 // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2558 // the corresponding stride may have to be skipped. (See above comment.)
2559 // Therefore, the result stride cannot be statically determined and must
2560 // be dynamic.
2561 resultStrides.push_back(ShapedType::kDynamic);
2562 }
2563 }
2564
2565 // Validate that each reassociation group is contiguous.
2566 unsigned resultStrideIndex = resultStrides.size() - 1;
2567 for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2568 auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2569 auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2570 for (int64_t idx : llvm::reverse(trailingReassocs)) {
2571 stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2572
2573 // Both source and result stride must have the same static value. In that
2574 // case, we can be sure, that the dimensions are collapsible (because they
2575 // are contiguous).
2576 // If `strict = false` (default during op verification), we accept cases
2577 // where one or both strides are dynamic. This is best effort: We reject
2578 // ops where obviously non-contiguous dims are collapsed, but accept ops
2579 // where we cannot be sure statically. Such ops may fail at runtime. See
2580 // the op documentation for details.
2581 auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2582 if (strict && (stride.saturated || srcStride.saturated))
2583 return failure();
2584
2585 // Dimensions of size 1 should be skipped, because their strides are
2586 // meaningless and could have any arbitrary value.
2587 if (srcShape[idx - 1] == 1)
2588 continue;
2589
2590 if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2591 return failure();
2592 }
2593 }
2594 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2595}
2596
2597bool CollapseShapeOp::isGuaranteedCollapsible(
2598 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2599 // MemRefs with identity layout are always collapsible.
2600 if (srcType.getLayout().isIdentity())
2601 return true;
2602
2603 return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2604 /*strict=*/true));
2605}
2606
2607MemRefType CollapseShapeOp::computeCollapsedType(
2608 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2609 SmallVector<int64_t> resultShape;
2610 resultShape.reserve(reassociation.size());
2611 for (const ReassociationIndices &group : reassociation) {
2612 auto groupSize = SaturatedInteger::wrap(1);
2613 for (int64_t srcDim : group)
2614 groupSize =
2615 groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2616 resultShape.push_back(groupSize.asInteger());
2617 }
2618
2619 if (srcType.getLayout().isIdentity()) {
2620 // If the source is contiguous (i.e., no layout map specified), so is the
2621 // result.
2622 MemRefLayoutAttrInterface layout;
2623 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2624 srcType.getMemorySpace());
2625 }
2626
2627 // Source may not be fully contiguous. Compute the layout map.
2628 // Note: Dimensions that are collapsed into a single dim are assumed to be
2629 // contiguous.
2630 FailureOr<StridedLayoutAttr> computedLayout =
2631 computeCollapsedLayoutMap(srcType, reassociation);
2632 assert(succeeded(computedLayout) &&
2633 "invalid source layout map or collapsing non-contiguous dims");
2634 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2635 srcType.getMemorySpace());
2636}
2637
2638void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2639 ArrayRef<ReassociationIndices> reassociation,
2640 ArrayRef<NamedAttribute> attrs) {
2641 auto srcType = llvm::cast<MemRefType>(src.getType());
2642 MemRefType resultType =
2643 CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2644 result.addAttribute(getReassociationAttrStrName(),
2645 getReassociationIndicesAttribute(b, reassociation));
2646 build(b, result, resultType, src, attrs);
2647}
2648
2649LogicalResult CollapseShapeOp::verify() {
2650 MemRefType srcType = getSrcType();
2651 MemRefType resultType = getResultType();
2652
2653 if (srcType.getRank() < resultType.getRank()) {
2654 auto r0 = srcType.getRank();
2655 auto r1 = resultType.getRank();
2656 return emitOpError("has source rank ")
2657 << r0 << " and result rank " << r1 << ". This is not a collapse ("
2658 << r0 << " < " << r1 << ").";
2659 }
2660
2661 // Verify result shape.
2662 if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2663 srcType.getShape(), getReassociationIndices(),
2664 /*allowMultipleDynamicDimsPerGroup=*/true)))
2665 return failure();
2666
2667 // Compute expected result type (including layout map).
2668 MemRefType expectedResultType;
2669 if (srcType.getLayout().isIdentity()) {
2670 // If the source is contiguous (i.e., no layout map specified), so is the
2671 // result.
2672 MemRefLayoutAttrInterface layout;
2673 expectedResultType =
2674 MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2675 srcType.getMemorySpace());
2676 } else {
2677 // Source may not be fully contiguous. Compute the layout map.
2678 // Note: Dimensions that are collapsed into a single dim are assumed to be
2679 // contiguous.
2680 FailureOr<StridedLayoutAttr> computedLayout =
2681 computeCollapsedLayoutMap(srcType, getReassociationIndices());
2682 if (failed(computedLayout))
2683 return emitOpError(
2684 "invalid source layout map or collapsing non-contiguous dims");
2685 expectedResultType =
2686 MemRefType::get(resultType.getShape(), srcType.getElementType(),
2687 *computedLayout, srcType.getMemorySpace());
2688 }
2689
2690 if (expectedResultType != resultType)
2691 return emitOpError("expected collapsed type to be ")
2692 << expectedResultType << " but found " << resultType;
2693
2694 return success();
2695}
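
// A valid collapse checked by this verifier (illustrative):
// ```
// %c = memref.collapse_shape %e [[0, 1], [2]]
//     : memref<2x5x4xf32> into memref<10x4xf32>
// ```
// For non-identity layouts, computeCollapsedLayoutMap additionally checks
// that each collapsed group is contiguous (best effort for dynamic strides).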
2696
2697struct CollapseShapeOpMemRefCastFolder
2698 : public OpRewritePattern<CollapseShapeOp> {
2699public:
2700 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2701
2702 LogicalResult matchAndRewrite(CollapseShapeOp op,
2703 PatternRewriter &rewriter) const override {
2704 auto cast = op.getOperand().getDefiningOp<CastOp>();
2705 if (!cast)
2706 return failure();
2707
2708 if (!CastOp::canFoldIntoConsumerOp(cast))
2709 return failure();
2710
2711 Type newResultType = CollapseShapeOp::computeCollapsedType(
2712 llvm::cast<MemRefType>(cast.getOperand().getType()),
2713 op.getReassociationIndices());
2714
2715 if (newResultType == op.getResultType()) {
2716 rewriter.modifyOpInPlace(
2717 op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2718 } else {
2719 Value newOp =
2720 CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
2721 op.getReassociationIndices());
2722 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2723 }
2724 return success();
2725 }
2726};
2727
2728void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2729 MLIRContext *context) {
2730 results.add<
2731 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2732 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2733 memref::DimOp, MemRefType>,
2734 CollapseShapeOpMemRefCastFolder>(context);
2735}
2736
2737OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2738 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2739 adaptor.getOperands());
2740}
2741
2742OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2743 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2744 adaptor.getOperands());
2745}
2746
2747FailureOr<std::optional<SmallVector<Value>>>
2748CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
2749 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2750}
2751
2752//===----------------------------------------------------------------------===//
2753// ReshapeOp
2754//===----------------------------------------------------------------------===//
2755
2756void ReshapeOp::getAsmResultNames(
2757 function_ref<void(Value, StringRef)> setNameFn) {
2758 setNameFn(getResult(), "reshape");
2759}
2760
2761LogicalResult ReshapeOp::verify() {
2762 Type operandType = getSource().getType();
2763 Type resultType = getResult().getType();
2764
2765 Type operandElementType =
2766 llvm::cast<ShapedType>(operandType).getElementType();
2767 Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2768 if (operandElementType != resultElementType)
2769 return emitOpError("element types of source and destination memref "
2770 "types should be the same");
2771
2772 if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2773 if (!operandMemRefType.getLayout().isIdentity())
2774 return emitOpError("source memref type should have identity affine map");
2775
2776 int64_t shapeSize =
2777 llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2778 auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2779 if (resultMemRefType) {
2780 if (!resultMemRefType.getLayout().isIdentity())
2781 return emitOpError("result memref type should have identity affine map");
2782 if (shapeSize == ShapedType::kDynamic)
2783 return emitOpError("cannot use shape operand with dynamic length to "
2784 "reshape to statically-ranked memref type");
2785 if (shapeSize != resultMemRefType.getRank())
2786 return emitOpError(
2787 "length of shape operand differs from the result's memref rank");
2788 }
2789 return success();
2790}
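
// A well-formed reshape with a statically known shape length (illustrative):
// ```
// %dst = memref.reshape %src(%shape)
//     : (memref<4x5xf32>, memref<2xindex>) -> memref<?x?xf32>
// ```
// The shape operand has static length 2, matching the result rank.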
2791
2792FailureOr<std::optional<SmallVector<Value>>>
2793ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
2794 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2795}
2796
2797//===----------------------------------------------------------------------===//
2798// StoreOp
2799//===----------------------------------------------------------------------===//
2800
2801LogicalResult StoreOp::verify() {
2802 if (getNumOperands() != 2 + getMemRefType().getRank())
2803 return emitOpError("store index operand count not equal to memref rank");
2804
2805 return success();
2806}
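
// E.g., storing to a rank-1 memref takes exactly one index (illustrative):
// ```
// memref.store %v, %m[%i] : memref<8xf32>
// ```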
2807
2808LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2809 SmallVectorImpl<OpFoldResult> &results) {
2810 /// store(memrefcast) -> store
2811 return foldMemRefCast(*this, getValueToStore());
2812}
2813
2814FailureOr<std::optional<SmallVector<Value>>>
2815StoreOp::bubbleDownCasts(OpBuilder &builder) {
2817 ValueRange());
2818}
2819
2820//===----------------------------------------------------------------------===//
2821// SubViewOp
2822//===----------------------------------------------------------------------===//
2823
2824void SubViewOp::getAsmResultNames(
2825 function_ref<void(Value, StringRef)> setNameFn) {
2826 setNameFn(getResult(), "subview");
2827}
2828
2829/// A subview result type can be fully inferred from the source type and the
2830/// static representation of offsets, sizes and strides. Special sentinels
2831/// encode the dynamic case.
2832MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2833 ArrayRef<int64_t> staticOffsets,
2834 ArrayRef<int64_t> staticSizes,
2835 ArrayRef<int64_t> staticStrides) {
2836 unsigned rank = sourceMemRefType.getRank();
2837 (void)rank;
2838 assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2839 assert(staticSizes.size() == rank && "staticSizes length mismatch");
2840 assert(staticStrides.size() == rank && "staticStrides length mismatch");
2841
2842 // Extract source offset and strides.
2843 auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2844
2845 // Compute target offset whose value is:
2846 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2847 int64_t targetOffset = sourceOffset;
2848 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2849 auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2850 targetOffset = (SaturatedInteger::wrap(targetOffset) +
2851 SaturatedInteger::wrap(staticOffset) *
2852 SaturatedInteger::wrap(sourceStride))
2853 .asInteger();
2854 }
2855
2856 // Compute target stride whose value is:
2857 // `sourceStrides_i * staticStrides_i`.
2858 SmallVector<int64_t, 4> targetStrides;
2859 targetStrides.reserve(staticOffsets.size());
2860 for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2861 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2862 targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2863 SaturatedInteger::wrap(staticStride))
2864 .asInteger());
2865 }
2866
2867 // The type is now known.
2868 return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2869 StridedLayoutAttr::get(sourceMemRefType.getContext(),
2870 targetOffset, targetStrides),
2871 sourceMemRefType.getMemorySpace());
2872}
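
// Worked example (illustrative): for a contiguous source memref<8x16x4xf32>
// (strides [64, 4, 1], offset 0) with offsets [3, 4, 2], sizes [3, 3, 3],
// and strides [1, 1, 1], the target offset is 3*64 + 4*4 + 2*1 = 210 and the
// target strides remain [64, 4, 1], so the inferred type is:
// ```
// memref<3x3x3xf32, strided<[64, 4, 1], offset: 210>>
// ```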
2873
2874MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2875 ArrayRef<OpFoldResult> offsets,
2876 ArrayRef<OpFoldResult> sizes,
2877 ArrayRef<OpFoldResult> strides) {
2878 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2879 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2880 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2881 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2882 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2883 if (!hasValidSizesOffsets(staticOffsets))
2884 return {};
2885 if (!hasValidSizesOffsets(staticSizes))
2886 return {};
2887 if (!hasValidStrides(staticStrides))
2888 return {};
2889 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2890 staticSizes, staticStrides);
2891}
2892
2893MemRefType SubViewOp::inferRankReducedResultType(
2894 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2895 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2896 ArrayRef<int64_t> strides) {
2897 MemRefType inferredType =
2898 inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2899 assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2900 "expected ");
2901 if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2902 return inferredType;
2903
2904 // Compute which dimensions are dropped.
2905 std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2906 computeRankReductionMask(inferredType.getShape(), resultShape);
2907 assert(dimsToProject.has_value() && "invalid rank reduction");
2908
2909 // Compute the layout and result type.
2910 auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2911 SmallVector<int64_t> rankReducedStrides;
2912 rankReducedStrides.reserve(resultShape.size());
2913 for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2914 if (!dimsToProject->contains(idx))
2915 rankReducedStrides.push_back(value);
2916 }
2917 return MemRefType::get(resultShape, inferredType.getElementType(),
2918 StridedLayoutAttr::get(inferredLayout.getContext(),
2919 inferredLayout.getOffset(),
2920 rankReducedStrides),
2921 inferredType.getMemorySpace());
2922}
2923
2924MemRefType SubViewOp::inferRankReducedResultType(
2925 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2926 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2927 ArrayRef<OpFoldResult> strides) {
2928 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2929 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2930 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2931 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2932 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2933 return SubViewOp::inferRankReducedResultType(
2934 resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2935 staticStrides);
2936}
2937
2938// Build a SubViewOp with mixed static and dynamic entries and custom result
2939// type. If the type passed is nullptr, it is inferred.
2940void SubViewOp::build(OpBuilder &b, OperationState &result,
2941 MemRefType resultType, Value source,
2942 ArrayRef<OpFoldResult> offsets,
2943 ArrayRef<OpFoldResult> sizes,
2944 ArrayRef<OpFoldResult> strides,
2945 ArrayRef<NamedAttribute> attrs) {
2946 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2947 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2948 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2949 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2950 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2951 auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2952 // Structuring the implementation this way avoids duplication between builders.
2953 if (!resultType) {
2954 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2955 staticSizes, staticStrides);
2956 }
2957 result.addAttributes(attrs);
2958 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2959 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2960 b.getDenseI64ArrayAttr(staticSizes),
2961 b.getDenseI64ArrayAttr(staticStrides));
2962}
2963
2964// Build a SubViewOp with mixed static and dynamic entries and inferred result
2965// type.
2966void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2967 ArrayRef<OpFoldResult> offsets,
2968 ArrayRef<OpFoldResult> sizes,
2969 ArrayRef<OpFoldResult> strides,
2970 ArrayRef<NamedAttribute> attrs) {
2971 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2972}
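// Usage sketch (editor's illustration; `b`, `loc`, and `src` are assumed to
// be an OpBuilder, a Location, and a Value of ranked 2-D memref type — none
// of these names come from the original source):
//   SmallVector<OpFoldResult> offsets(2, b.getIndexAttr(0));
//   SmallVector<OpFoldResult> sizes = memref::getMixedSizes(b, loc, src);
//   SmallVector<OpFoldResult> strides(2, b.getIndexAttr(1));
//   // No MemRefType is supplied, so the result type is inferred.
//   auto subview =
//       memref::SubViewOp::create(b, loc, src, offsets, sizes, strides);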
2973
2974// Build a SubViewOp with static entries and inferred result type.
2975void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2976 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2977 ArrayRef<int64_t> strides,
2978 ArrayRef<NamedAttribute> attrs) {
2979 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2980 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2981 return b.getI64IntegerAttr(v);
2982 }));
2983 SmallVector<OpFoldResult> sizeValues =
2984 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2985 return b.getI64IntegerAttr(v);
2986 }));
2987 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2988 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2989 return b.getI64IntegerAttr(v);
2990 }));
2991 build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2992}
2993
2994// Build a SubViewOp with static entries and custom result type. If the
2995// type passed is nullptr, it is inferred.
2996void SubViewOp::build(OpBuilder &b, OperationState &result,
2997 MemRefType resultType, Value source,
2998 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2999 ArrayRef<int64_t> strides,
3000 ArrayRef<NamedAttribute> attrs) {
3001 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3002 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
3003 return b.getI64IntegerAttr(v);
3004 }));
3005 SmallVector<OpFoldResult> sizeValues =
3006 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
3007 return b.getI64IntegerAttr(v);
3008 }));
3009 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3010 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
3011 return b.getI64IntegerAttr(v);
3012 }));
3013 build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
3014 attrs);
3015}
3016
3017// Build a SubViewOp with dynamic entries and custom result type. If the type
3018// passed is nullptr, it is inferred.
3019void SubViewOp::build(OpBuilder &b, OperationState &result,
3020 MemRefType resultType, Value source, ValueRange offsets,
3021 ValueRange sizes, ValueRange strides,
3022 ArrayRef<NamedAttribute> attrs) {
3023 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3024 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3025 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3026 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3027 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3028 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3029 build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
3030 attrs);
3030}
3031
3032// Build a SubViewOp with dynamic entries and inferred result type.
3033void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3034 ValueRange offsets, ValueRange sizes, ValueRange strides,
3035 ArrayRef<NamedAttribute> attrs) {
3036 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
3037}
3038
3039/// For ViewLikeOpInterface.
3040Value SubViewOp::getViewSource() { return getSource(); }
3041
3042/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
3043/// static value).
3044static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
3045 int64_t t1Offset, t2Offset;
3046 SmallVector<int64_t> t1Strides, t2Strides;
3047 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3048 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3049 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
3050}
3051
3052/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
3053/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
3054/// marked as dropped in `droppedDims`.
3055static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
3056 const llvm::SmallBitVector &droppedDims) {
3057 assert(size_t(t1.getRank()) == droppedDims.size() &&
3058 "incorrect number of bits");
3059 assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
3060 "incorrect number of dropped dims");
3061 int64_t t1Offset, t2Offset;
3062 SmallVector<int64_t> t1Strides, t2Strides;
3063 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3064 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3065 if (failed(res1) || failed(res2))
3066 return false;
3067 for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
3068 if (droppedDims[i])
3069 continue;
3070 if (t1Strides[i] != t2Strides[j])
3071 return false;
3072 ++j;
3073 }
3074 return true;
3075}
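// Example (editor's note): for t1 = memref<4x1x4xf32, strided<[16, 16, 2]>>
// with droppedDims = {1} and t2 = memref<4x4xf32, strided<[16, 2]>>, the walk
// above compares strides 16 and 2 (skipping the dropped unit dim) and returns
// true; a t2 stride list of [16, 1] would be rejected.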
3076
3077 static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
3078 SubViewOp op, Type expectedType) {
3079 auto memrefType = llvm::cast<ShapedType>(expectedType);
3080 switch (result) {
3081 case SliceVerificationResult::Success:
3082 return success();
3083 case SliceVerificationResult::RankTooLarge:
3084 return op->emitError("expected result rank to be smaller than or equal to ")
3085 << "the source rank, but got " << op.getType();
3086 case SliceVerificationResult::SizeMismatch:
3087 return op->emitError("expected result type to be ")
3088 << expectedType
3089 << " or a rank-reduced version. (mismatch of result sizes), but got "
3090 << op.getType();
3091 case SliceVerificationResult::ElemTypeMismatch:
3092 return op->emitError("expected result element type to be ")
3093 << memrefType.getElementType() << ", but got " << op.getType();
3094 case SliceVerificationResult::MemSpaceMismatch:
3095 return op->emitError(
3096 "expected result and source memory spaces to match, but got ")
3097 << op.getType();
3098 case SliceVerificationResult::LayoutMismatch:
3099 return op->emitError("expected result type to be ")
3100 << expectedType
3101 << " or a rank-reduced version. (mismatch of result layout), but "
3102 "got "
3103 << op.getType();
3104 }
3105 llvm_unreachable("unexpected subview verification result");
3106}
3107
3108/// Verifier for SubViewOp.
3109LogicalResult SubViewOp::verify() {
3110 MemRefType baseType = getSourceType();
3111 MemRefType subViewType = getType();
3112 ArrayRef<int64_t> staticOffsets = getStaticOffsets();
3113 ArrayRef<int64_t> staticSizes = getStaticSizes();
3114 ArrayRef<int64_t> staticStrides = getStaticStrides();
3115
3116 // The base memref and the view memref should be in the same memory space.
3117 if (baseType.getMemorySpace() != subViewType.getMemorySpace())
3118 return emitError("different memory spaces specified for base memref "
3119 "type ")
3120 << baseType << " and subview memref type " << subViewType;
3121
3122 // Verify that the base memref type has a strided layout map.
3123 if (!baseType.isStrided())
3124 return emitError("base type ") << baseType << " is not strided";
3125
3126 // Compute the expected result type, assuming that there are no rank
3127 // reductions.
3128 MemRefType expectedType = SubViewOp::inferResultType(
3129 baseType, staticOffsets, staticSizes, staticStrides);
3130
3131 // Verify all properties of a shaped type: rank, element type and dimension
3132 // sizes. This takes into account potential rank reductions.
3133 auto shapedTypeVerification = isRankReducedType(
3134 /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
3135 if (shapedTypeVerification != SliceVerificationResult::Success)
3136 return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
3137
3138 // Make sure that the memory space did not change.
3139 if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3140 return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
3141 *this, expectedType);
3142
3143 // Verify the offset of the layout map.
3144 if (!haveCompatibleOffsets(expectedType, subViewType))
3145 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3146 *this, expectedType);
3147
3148 // The only thing that's left to verify now are the strides. First, compute
3149 // the unused dimensions due to rank reductions. We have to look at sizes and
3150 // strides to decide which dimensions were dropped. This function also
3151 // partially verifies strides in case of rank reductions.
3152 auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3153 getMixedSizes());
3154 if (failed(unusedDims))
3155 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3156 *this, expectedType);
3157
3158 // Strides must match.
3159 if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3160 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3161 *this, expectedType);
3162
3163 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3164 // to the base memref.
3165 SliceBoundsVerificationResult boundsResult =
3166 verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
3167 staticStrides, /*generateErrorMessage=*/true);
3168 if (!boundsResult.isValid)
3169 return getOperation()->emitError(boundsResult.errorMessage);
3170
3171 return success();
3172}
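// For intuition (editor's sketch): the bounds check above rejects, e.g.,
//   %0 = memref.subview %src[6] [4] [1]
//       : memref<8xf32> to memref<4xf32, strided<[1], offset: 6>>
// because the last accessed element, 6 + (4 - 1) * 1 = 9, lies outside the
// base dimension of size 8.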
3173
3175 return os << "range " << range.offset << ":" << range.size << ":"
3176 << range.stride;
3177}
3178
3179/// Return the list of Range (i.e. offset, size, stride). Each Range
3180/// entry contains either the dynamic value or a ConstantIndexOp constructed
3181/// with `b` at location `loc`.
3182SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3183 OpBuilder &b, Location loc) {
3184 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3185 assert(ranks[0] == ranks[1] && "expected offsets and sizes of equal ranks");
3186 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3187 SmallVector<Range, 8> res;
3188 unsigned rank = ranks[0];
3189 res.reserve(rank);
3190 for (unsigned idx = 0; idx < rank; ++idx) {
3191 Value offset =
3192 op.isDynamicOffset(idx)
3193 ? op.getDynamicOffset(idx)
3194 : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
3195 Value size =
3196 op.isDynamicSize(idx)
3197 ? op.getDynamicSize(idx)
3198 : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
3199 Value stride =
3200 op.isDynamicStride(idx)
3201 ? op.getDynamicStride(idx)
3202 : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
3203 res.emplace_back(Range{offset, size, stride});
3204 }
3205 return res;
3206}
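// Usage sketch (editor's illustration; `op`, `b`, and `loc` are assumed to be
// in scope, with `op` implementing OffsetSizeAndStrideOpInterface):
//   SmallVector<Range, 8> ranges = getOrCreateRanges(op, b, loc);
//   // Static entries come back as freshly created arith constant index
//   // values; dynamic entries are passed through unchanged.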
3207
3208/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3209/// to deduce the result type for the given `sourceType`. Additionally, reduce
3210 /// the rank of the inferred result type if `currentResultType` has a lower
3211 /// rank than `currentSourceType`. Use this signature if `sourceType` is updated
3212/// together with the result type. In this case, it is important to compute
3213/// the dropped dimensions using `currentSourceType` whose strides align with
3214/// `currentResultType`.
3215 static MemRefType getCanonicalSubViewResultType(
3216 MemRefType currentResultType, MemRefType currentSourceType,
3217 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3218 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3219 MemRefType nonRankReducedType = SubViewOp::inferResultType(
3220 sourceType, mixedOffsets, mixedSizes, mixedStrides);
3221 FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3222 currentSourceType, currentResultType, mixedSizes);
3223 if (failed(unusedDims))
3224 return nullptr;
3225
3226 auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3227 SmallVector<int64_t> shape, strides;
3228 unsigned numDimsAfterReduction =
3229 nonRankReducedType.getRank() - unusedDims->count();
3230 shape.reserve(numDimsAfterReduction);
3231 strides.reserve(numDimsAfterReduction);
3232 for (const auto &[idx, size, stride] :
3233 llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3234 nonRankReducedType.getShape(), layout.getStrides())) {
3235 if (unusedDims->test(idx))
3236 continue;
3237 shape.push_back(size);
3238 strides.push_back(stride);
3239 }
3240
3241 return MemRefType::get(shape, nonRankReducedType.getElementType(),
3242 StridedLayoutAttr::get(sourceType.getContext(),
3243 layout.getOffset(), strides),
3244 nonRankReducedType.getMemorySpace());
3245}
3246
3247 Value mlir::memref::createCanonicalRankReducingSubViewOp(
3248 OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3249 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3250 unsigned rank = memrefType.getRank();
3251 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3252 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3253 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3254 MemRefType targetType = SubViewOp::inferRankReducedResultType(
3255 targetShape, memrefType, offsets, sizes, strides);
3256 return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3257 sizes, strides);
3258}
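// Illustration (editor's sketch): for %m of type memref<1x8xf32> and a
// targetShape of {8}, this creates the rank-reducing subview
//   %sv = memref.subview %m[0, 0] [1, 8] [1, 1]
//       : memref<1x8xf32> to memref<8xf32, strided<[1]>>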
3259
3260FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3261 Value value,
3262 ArrayRef<int64_t> desiredShape) {
3263 auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3264 assert(sourceMemrefType && "not a ranked memref type");
3265 auto sourceShape = sourceMemrefType.getShape();
3266 if (sourceShape.equals(desiredShape))
3267 return value;
3268 auto maybeRankReductionMask =
3269 mlir::computeRankReductionMask(sourceShape, desiredShape);
3270 if (!maybeRankReductionMask)
3271 return failure();
3272 return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3273}
3274
3275/// Helper method to check if a `subview` operation is trivially a no-op. This
3276/// is the case if all offsets are zero, all strides are 1, and the source
3277/// shape is the same as the size of the subview. In such cases, the subview
3278/// can be folded into its source.
3279static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3280 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3281 return false;
3282
3283 auto mixedOffsets = subViewOp.getMixedOffsets();
3284 auto mixedSizes = subViewOp.getMixedSizes();
3285 auto mixedStrides = subViewOp.getMixedStrides();
3286
3287 // Check offsets are zero.
3288 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3289 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3290 return !intValue || intValue.value() != 0;
3291 }))
3292 return false;
3293
3294 // Check strides are one.
3295 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3296 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3297 return !intValue || intValue.value() != 1;
3298 }))
3299 return false;
3300
3301 // Check that all size values are static and match the (static) source shape.
3302 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3303 for (const auto &size : llvm::enumerate(mixedSizes)) {
3304 std::optional<int64_t> intValue = getConstantIntValue(size.value());
3305 if (!intValue || *intValue != sourceShape[size.index()])
3306 return false;
3307 }
3308 // All conditions met. The `SubViewOp` is foldable as a no-op.
3309 return true;
3310}
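// Example (editor's note): a subview with all-zero offsets, all-one strides,
// and sizes equal to the static source shape, e.g.
//   %0 = memref.subview %src[0, 0] [4, 4] [1, 1] : memref<4x4xf32> to ...
// satisfies all three checks and can be folded away (possibly via a cast when
// only the layouts differ).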
3311
3312namespace {
3313/// Pattern to rewrite a subview op with MemRefCast arguments.
3314/// This essentially pushes memref.cast past its consuming subview when
3315/// `canFoldIntoConsumerOp` is true.
3316///
3317/// Example:
3318/// ```
3319/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3320/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3321/// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3322/// ```
3323/// is rewritten into:
3324/// ```
3325 /// %0 = memref.subview %V: memref<16x16xf32> to
3326 /// memref<3x4xf32, strided<[16, 1], offset: 0>>
3326/// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3327/// memref<3x4xf32, strided<[?, 1], offset: ?>>
3328/// ```
3329class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3330public:
3331 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3332
3333 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3334 PatternRewriter &rewriter) const override {
3335 // If any operand is a constant, just return and let the
3336 // SubViewOpConstantFolder kick in.
3337 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3338 return matchPattern(operand, matchConstantIndex());
3339 }))
3340 return failure();
3341
3342 auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3343 if (!castOp)
3344 return failure();
3345
3346 if (!CastOp::canFoldIntoConsumerOp(castOp))
3347 return failure();
3348
3349 // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3350 // the MemRefCastOp source operand type to infer the result type and the
3351 // current SubViewOp source operand type to compute the dropped dimensions
3352 // if the operation is rank-reducing.
3353 auto resultType = getCanonicalSubViewResultType(
3354 subViewOp.getType(), subViewOp.getSourceType(),
3355 llvm::cast<MemRefType>(castOp.getSource().getType()),
3356 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3357 subViewOp.getMixedStrides());
3358 if (!resultType)
3359 return failure();
3360
3361 Value newSubView = SubViewOp::create(
3362 rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
3363 subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3364 subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3365 subViewOp.getStaticStrides());
3366 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3367 newSubView);
3368 return success();
3369 }
3370};
3371
3372/// Canonicalize subview ops that are no-ops; if the source and result types
3373/// differ only in layout (e.g. due to use of `affine_map`), insert a cast.
3374class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3375public:
3376 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3377
3378 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3379 PatternRewriter &rewriter) const override {
3380 if (!isTrivialSubViewOp(subViewOp))
3381 return failure();
3382 if (subViewOp.getSourceType() == subViewOp.getType()) {
3383 rewriter.replaceOp(subViewOp, subViewOp.getSource());
3384 return success();
3385 }
3386 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3387 subViewOp.getSource());
3388 return success();
3389 }
3390};
3391} // namespace
3392
3393/// Return the canonical type of the result of a subview.
3394 struct SubViewReturnTypeCanonicalizer {
3395 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3396 ArrayRef<OpFoldResult> mixedSizes,
3397 ArrayRef<OpFoldResult> mixedStrides) {
3398 // Infer a memref type without taking into account any rank reductions.
3399 MemRefType resTy = SubViewOp::inferResultType(
3400 op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3401 if (!resTy)
3402 return {};
3403 MemRefType nonReducedType = resTy;
3404
3405 // Directly return the non-rank reduced type if there are no dropped dims.
3406 llvm::SmallBitVector droppedDims = op.getDroppedDims();
3407 if (droppedDims.none())
3408 return nonReducedType;
3409
3410 // Take the strides and offset from the non-rank reduced type.
3411 auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3412
3413 // Drop dims from shape and strides.
3414 SmallVector<int64_t> targetShape;
3415 SmallVector<int64_t> targetStrides;
3416 for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3417 if (droppedDims.test(i))
3418 continue;
3419 targetStrides.push_back(nonReducedStrides[i]);
3420 targetShape.push_back(nonReducedType.getDimSize(i));
3421 }
3422
3423 return MemRefType::get(targetShape, nonReducedType.getElementType(),
3424 StridedLayoutAttr::get(nonReducedType.getContext(),
3425 offset, targetStrides),
3426 nonReducedType.getMemorySpace());
3427 }
3428};
3429
3430/// A canonicalizer wrapper to replace SubViewOps.
3431 struct SubViewCanonicalizer {
3432 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3433 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3434 }
3435};
3436
3437void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3438 MLIRContext *context) {
3439 results
3440 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3441 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3442 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3443}
3444
3445OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3446 MemRefType sourceMemrefType = getSource().getType();
3447 MemRefType resultMemrefType = getResult().getType();
3448 auto resultLayout =
3449 dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3450
3451 if (resultMemrefType == sourceMemrefType &&
3452 resultMemrefType.hasStaticShape() &&
3453 (!resultLayout || resultLayout.hasStaticLayout())) {
3454 return getViewSource();
3455 }
3456
3457 // Fold subview(subview(x)), where both subviews have the same size and the
3458 // second subview's offsets are all zero. (I.e., the second subview is a
3459 // no-op.)
3460 if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3461 auto srcSizes = srcSubview.getMixedSizes();
3462 auto sizes = getMixedSizes();
3463 auto offsets = getMixedOffsets();
3464 bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3465 auto strides = getMixedStrides();
3466 bool allStridesOne = llvm::all_of(strides, isOneInteger);
3467 bool allSizesSame = llvm::equal(sizes, srcSizes);
3468 if (allOffsetsZero && allStridesOne && allSizesSame &&
3469 resultMemrefType == sourceMemrefType)
3470 return getViewSource();
3471 }
3472
3473 return {};
3474}
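// Illustration (editor's sketch) of the subview(subview(x)) fold above:
//   %a = memref.subview %m[2] [4] [1]
//       : memref<8xf32> to memref<4xf32, strided<[1], offset: 2>>
//   %b = memref.subview %a[0] [4] [1]
//       : memref<4xf32, strided<[1], offset: 2>>
//         to memref<4xf32, strided<[1], offset: 2>>
// Here %b has zero offsets, unit strides, the same sizes as %a, and the same
// type as its source, so it folds to %a.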
3475
3476FailureOr<std::optional<SmallVector<Value>>>
3477SubViewOp::bubbleDownCasts(OpBuilder &builder) {
3478 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3479}
3480
3481void SubViewOp::inferStridedMetadataRanges(
3482 ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
3483 SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
3484 auto isUninitialized =
3485 +[](IntegerValueRange range) { return range.isUninitialized(); };
3486
3487 // Bail out early if any of the operands' metadata is not ready:
3488 SmallVector<IntegerValueRange> offsetOperands =
3489 getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
3490 if (llvm::any_of(offsetOperands, isUninitialized))
3491 return;
3492
3493 SmallVector<IntegerValueRange> sizeOperands =
3494 getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
3495 if (llvm::any_of(sizeOperands, isUninitialized))
3496 return;
3497
3498 SmallVector<IntegerValueRange> stridesOperands =
3499 getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
3500 if (llvm::any_of(stridesOperands, isUninitialized))
3501 return;
3502
3503 StridedMetadataRange sourceRange =
3504 ranges[getSourceMutable().getOperandNumber()];
3505 if (sourceRange.isUninitialized())
3506 return;
3507
3508 ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
3509
3510 // Get the dropped dims.
3511 llvm::SmallBitVector droppedDims = getDroppedDims();
3512
3513 // Compute the new offset, strides and sizes.
3514 ConstantIntRanges offset = sourceRange.getOffsets()[0];
3515 SmallVector<ConstantIntRanges> strides, sizes;
3516
3517 for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
3518 bool dropped = droppedDims.test(i);
3519 // Compute the new offset.
3520 ConstantIntRanges off =
3521 intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
3522 offset = intrange::inferAdd({offset, off});
3523
3524 // Skip dropped dimensions.
3525 if (dropped)
3526 continue;
3527 // Multiply the strides.
3528 strides.push_back(
3529 intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
3530 // Get the sizes.
3531 sizes.push_back(sizeOperands[i].getValue());
3532 }
3533
3534 setMetadata(getResult(),
3535 StridedMetadataRange::getRanked(
3536 SmallVector<ConstantIntRanges>({std::move(offset)}),
3537 std::move(sizes), std::move(strides)));
3538}
3539
3540//===----------------------------------------------------------------------===//
3541// TransposeOp
3542//===----------------------------------------------------------------------===//
3543
3544void TransposeOp::getAsmResultNames(
3545 function_ref<void(Value, StringRef)> setNameFn) {
3546 setNameFn(getResult(), "transpose");
3547}
3548
3549/// Build a strided memref type by applying `permutationMap` to `memRefType`.
3550static MemRefType inferTransposeResultType(MemRefType memRefType,
3551 AffineMap permutationMap) {
3552 auto originalSizes = memRefType.getShape();
3553 auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3554 assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3555
3556 // Compute permuted sizes and strides.
3557 auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3558 auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3559
3560 return MemRefType::Builder(memRefType)
3561 .setShape(sizes)
3562 .setLayout(
3563 StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3564}
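// Worked example (editor's sketch): for memref<4x8xf32> (strides [8, 1],
// offset 0) and the permutation map (d0, d1) -> (d1, d0), the permuted sizes
// are [8, 4] and the permuted strides are [1, 8], giving:
//   memref<8x4xf32, strided<[1, 8]>>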
3565
3566void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3567 AffineMapAttr permutation,
3568 ArrayRef<NamedAttribute> attrs) {
3569 auto permutationMap = permutation.getValue();
3570 assert(permutationMap);
3571
3572 auto memRefType = llvm::cast<MemRefType>(in.getType());
3573 // Compute result type.
3574 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3575
3576 result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3577 build(b, result, resultType, in, attrs);
3578}
3579
3580// transpose $in $permutation attr-dict : type($in) `to` type(results)
3581void TransposeOp::print(OpAsmPrinter &p) {
3582 p << " " << getIn() << " " << getPermutation();
3583 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3584 p << " : " << getIn().getType() << " to " << getType();
3585}
3586
3587ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3588 OpAsmParser::UnresolvedOperand in;
3589 AffineMap permutation;
3590 MemRefType srcType, dstType;
3591 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3592 parser.parseOptionalAttrDict(result.attributes) ||
3593 parser.parseColonType(srcType) ||
3594 parser.resolveOperand(in, srcType, result.operands) ||
3595 parser.parseKeywordType("to", dstType) ||
3596 parser.addTypeToList(dstType, result.types))
3597 return failure();
3598
3599 result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3600 AffineMapAttr::get(permutation));
3601 return success();
3602}
3603
3604LogicalResult TransposeOp::verify() {
3605 if (!getPermutation().isPermutation())
3606 return emitOpError("expected a permutation map");
3607 if (getPermutation().getNumDims() != getIn().getType().getRank())
3608 return emitOpError("expected a permutation map of same rank as the input");
3609
3610 auto srcType = llvm::cast<MemRefType>(getIn().getType());
3611 auto resultType = llvm::cast<MemRefType>(getType());
3612 auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3613 .canonicalizeStridedLayout();
3614
3615 if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3616 return emitOpError("result type ")
3617 << resultType
3618 << " is not equivalent to the canonical transposed input type "
3619 << canonicalResultType;
3620 return success();
3621}
3622
3623OpFoldResult TransposeOp::fold(FoldAdaptor) {
3624 // First, check for an identity permutation; we can fold it away if the
3625 // input and result types are already identical.
3626 if (getPermutation().isIdentity() && getType() == getIn().getType())
3627 return getIn();
3628 // Fold two consecutive memref.transpose Ops into one by composing their
3629 // permutation maps.
3630 if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3631 AffineMap composedPermutation =
3632 getPermutation().compose(otherTransposeOp.getPermutation());
3633 getInMutable().assign(otherTransposeOp.getIn());
3634 setPermutation(composedPermutation);
3635 return getResult();
3636 }
3637 return {};
3638}
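// Illustration (editor's sketch) of the transpose-of-transpose fold:
//   %t0 = memref.transpose %m (d0, d1) -> (d1, d0)
//       : memref<4x8xf32> to memref<8x4xf32, strided<[1, 8]>>
//   %t1 = memref.transpose %t0 (d0, d1) -> (d1, d0)
//       : memref<8x4xf32, strided<[1, 8]>> to memref<4x8xf32, strided<[8, 1]>>
// The fold rewrites %t1 in place into a single transpose of %m with the
// composed (here identity) permutation; the identity case above can then
// remove it entirely once the types match.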
3639
3640FailureOr<std::optional<SmallVector<Value>>>
3641TransposeOp::bubbleDownCasts(OpBuilder &builder) {
3642 return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
3643}
3644
3645//===----------------------------------------------------------------------===//
3646// ViewOp
3647//===----------------------------------------------------------------------===//
3648
3649void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3650 setNameFn(getResult(), "view");
3651}
3652
3653LogicalResult ViewOp::verify() {
3654 auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3655 auto viewType = getType();
3656
3657 // The base memref should have identity layout map (or none).
3658 if (!baseType.getLayout().isIdentity())
3659 return emitError("unsupported map for base memref type ") << baseType;
3660
3661 // The result memref should have identity layout map (or none).
3662 if (!viewType.getLayout().isIdentity())
3663 return emitError("unsupported map for result memref type ") << viewType;
3664
3665 // The base memref and the view memref should be in the same memory space.
3666 if (baseType.getMemorySpace() != viewType.getMemorySpace())
3667 return emitError("different memory spaces specified for base memref "
3668 "type ")
3669 << baseType << " and view memref type " << viewType;
3670
3671 // Verify that we have the correct number of sizes for the result type.
3672 unsigned numDynamicDims = viewType.getNumDynamicDims();
3673 if (getSizes().size() != numDynamicDims)
3674 return emitError("incorrect number of size operands for type ") << viewType;
3675
3676 return success();
3677}
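// Example (editor's sketch) of a view that satisfies the checks above: one
// size operand for the single dynamic dimension, identity layouts, and a
// shared memory space:
//   %v = memref.view %buf[%byte_shift][%n]
//       : memref<2048xi8> to memref<?x4xf32>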
3678
3679Value ViewOp::getViewSource() { return getSource(); }
3680
3681OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3682 MemRefType sourceMemrefType = getSource().getType();
3683 MemRefType resultMemrefType = getResult().getType();
3684
3685 if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
3686 return getViewSource();
3687
3688 return {};
3689}
3690
3691namespace {
3692
3693struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3694 using OpRewritePattern<ViewOp>::OpRewritePattern;
3695
3696 LogicalResult matchAndRewrite(ViewOp viewOp,
3697 PatternRewriter &rewriter) const override {
3698 // Return if none of the operands are constants.
3699 if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3700 return matchPattern(operand, matchConstantIndex());
3701 }))
3702 return failure();
3703
3704 // Get result memref type.
3705 auto memrefType = viewOp.getType();
3706
3707 // Get offset from old memref view type 'memrefType'.
3708 int64_t oldOffset;
3709 SmallVector<int64_t, 4> oldStrides;
3710 if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3711 return failure();
3712 assert(oldOffset == 0 && "Expected 0 offset");
3713
3714 SmallVector<Value, 4> newOperands;
3715
3716 // Offset cannot be folded into result type.
3717
3718 // Fold any dynamic dim operands which are produced by a constant.
3719 SmallVector<int64_t, 4> newShapeConstants;
3720 newShapeConstants.reserve(memrefType.getRank());
3721
3722 unsigned dynamicDimPos = 0;
3723 unsigned rank = memrefType.getRank();
3724 for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3725 int64_t dimSize = memrefType.getDimSize(dim);
3726 // If this is already a static dimension, keep it.
3727 if (ShapedType::isStatic(dimSize)) {
3728 newShapeConstants.push_back(dimSize);
3729 continue;
3730 }
3731 auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3732 if (auto constantIndexOp =
3733 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3734 // Dynamic shape dimension will be folded.
3735 newShapeConstants.push_back(constantIndexOp.value());
3736 } else {
3737 // Dynamic shape dimension not folded; copy operand from old memref.
3738 newShapeConstants.push_back(dimSize);
3739 newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3740 }
3741 dynamicDimPos++;
3742 }
3743
3744 // Create new memref type with constant folded dims.
3745 MemRefType newMemRefType =
3746 MemRefType::Builder(memrefType).setShape(newShapeConstants);
3747 // Nothing new, don't fold.
3748 if (newMemRefType == memrefType)
3749 return failure();
3750
3751 // Create new ViewOp.
3752 auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
3753 viewOp.getOperand(0), viewOp.getByteShift(),
3754 newOperands);
3755 // Insert a cast so we have the same type as the old memref type.
3756 rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3757 return success();
3758 }
3759};
3760
3761struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3762 using OpRewritePattern<ViewOp>::OpRewritePattern;
3763
3764 LogicalResult matchAndRewrite(ViewOp viewOp,
3765 PatternRewriter &rewriter) const override {
3766 Value memrefOperand = viewOp.getOperand(0);
3767 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3768 if (!memrefCastOp)
3769 return failure();
3770 Value allocOperand = memrefCastOp.getOperand();
3771 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3772 if (!allocOp)
3773 return failure();
3774 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3775 viewOp.getByteShift(),
3776 viewOp.getSizes());
3777 return success();
3778 }
3779};
3780
3781} // namespace
3782
3783void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3784 MLIRContext *context) {
3785 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3786}
3787
3788FailureOr<std::optional<SmallVector<Value>>>
3789ViewOp::bubbleDownCasts(OpBuilder &builder) {
3790 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3791}
3792
3793//===----------------------------------------------------------------------===//
3794// AtomicRMWOp
3795//===----------------------------------------------------------------------===//
3796
3797LogicalResult AtomicRMWOp::verify() {
3798 if (getMemRefType().getRank() != getNumOperands() - 2)
3799 return emitOpError(
3800 "expects the number of subscripts to be equal to memref rank");
3801 switch (getKind()) {
3802 case arith::AtomicRMWKind::addf:
3803 case arith::AtomicRMWKind::maximumf:
3804 case arith::AtomicRMWKind::minimumf:
3805 case arith::AtomicRMWKind::mulf:
3806 if (!llvm::isa<FloatType>(getValue().getType()))
3807 return emitOpError() << "with kind '"
3808 << arith::stringifyAtomicRMWKind(getKind())
3809 << "' expects a floating-point type";
3810 break;
3811 case arith::AtomicRMWKind::addi:
3812 case arith::AtomicRMWKind::maxs:
3813 case arith::AtomicRMWKind::maxu:
3814 case arith::AtomicRMWKind::mins:
3815 case arith::AtomicRMWKind::minu:
3816 case arith::AtomicRMWKind::muli:
3817 case arith::AtomicRMWKind::ori:
3818 case arith::AtomicRMWKind::xori:
3819 case arith::AtomicRMWKind::andi:
3820 if (!llvm::isa<IntegerType>(getValue().getType()))
3821 return emitOpError() << "with kind '"
3822 << arith::stringifyAtomicRMWKind(getKind())
3823 << "' expects an integer type";
3824 break;
3825 default:
3826 break;
3827 }
3828 return success();
3829}
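// Example (editor's sketch): a floating-point kind paired with an f32 value
// and one subscript per memref dimension passes the checks above:
//   %old = memref.atomic_rmw addf %val, %m[%i]
//       : (f32, memref<8xf32>) -> f32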
3830
3831OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3832 /// atomicrmw(memrefcast) -> atomicrmw
3833 if (succeeded(foldMemRefCast(*this, getValue())))
3834 return getResult();
3835 return OpFoldResult();
3836}
3837
3838FailureOr<std::optional<SmallVector<Value>>>
3839AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
3840 return bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
3841 getResult());
3842}
3843
3844//===----------------------------------------------------------------------===//
3845// TableGen'd op method definitions
3846//===----------------------------------------------------------------------===//
3847
3848#define GET_OP_CLASSES
3849#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool hasSideEffects(Operation *op)
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition IRAffine.cpp:67
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
auto load
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
Definition MemRefOps.cpp:96
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static LogicalResult FoldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:250
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
IndexType getIndexType()
Definition Builders.cpp:51
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Region.h:98
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StridedMetadataRange getRanked(SmallVectorImpl< ConstantIntRanges > &&offsets, SmallVectorImpl< ConstantIntRanges > &&sizes, SmallVectorImpl< ConstantIntRanges > &&strides)
Returns a ranked strided metadata range.
ArrayRef< ConstantIntRanges > getStrides() const
Get the strides ranges.
bool isUninitialized() const
Returns whether the metadata is uninitialized.
ArrayRef< ConstantIntRanges > getOffsets() const
Get the offsets range.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition MemRefOps.cpp:60
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition MemRefOps.cpp:68
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:45
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:77
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size as staticValues, but with every element for which ShapedType::isDynamic is true replaced by the corresponding value from dynamicValues.
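A sketch of the composing direction, where ctx (an MLIRContext*) and dynSize (an index-typed Value) are assumed placeholders:

  // Entries equal to ShapedType::kDynamic consume the next dynamic SSA value;
  // all other entries become index attributes.
  SmallVector<int64_t> staticSizes = {4, ShapedType::kDynamic, 8};
  SmallVector<Value> dynamicSizes = {dynSize};
  SmallVector<OpFoldResult> mixed =
      getMixedValues(staticSizes, dynamicSizes, ctx);
  // mixed now holds {IndexAttr(4), dynSize, IndexAttr(8)}.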
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
Definition Matchers.h:527
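A typical matcher sketch, with v an assumed Value:

  // Bind the integer behind v if its defining op is constant-like and holds
  // a scalar (or splat) integer.
  llvm::APInt cst;
  if (matchPattern(v, m_ConstantInt(&cst))) {
    // v is a compile-time constant; cst holds its value.
  }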
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of ops.
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
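A minimal sketch, with ofr an assumed OpFoldResult:

  // Handles both payloads of an OpFoldResult: an IntegerAttr, or a Value
  // defined by a constant op.
  if (std::optional<int64_t> cst = getConstantIntValue(ofr)) {
    // *cst is the compile-time value, e.g. a known-zero offset.
  }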
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
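A sketch with assumed static operands; per dimension, the last accessed index is offset + (size - 1) * stride:

  // In bounds for shape 16x32: dim 0 touches indices up to 15, dim 1 up to
  // 8 + (12 - 1) * 2 = 30.
  SliceBoundsVerificationResult res = verifyInBoundsSlice(
      /*shape=*/{16, 32}, /*staticOffsets=*/{0, 8},
      /*staticSizes=*/{16, 12}, /*staticStrides=*/{1, 2},
      /*generateErrorMessage=*/true);
  if (!res.isValid)
    return op->emitError(res.errorMessage);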
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e., offset, size, and stride); each entry holds either the dynamic value or a ConstantIndexOp constructed with b at loc.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldResult for a single OpFoldResult.
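A sketch of the decomposing direction, with mixedOffsets an assumed ArrayRef<OpFoldResult>:

  // Attributes land in staticOffsets; Values land in dynamicOffsets and leave
  // ShapedType::kDynamic as their placeholder in staticOffsets.
  SmallVector<Value> dynamicOffsets;
  SmallVector<int64_t> staticOffsets;
  dispatchIndexOpFoldResults(mixedOffsets, dynamicOffsets, staticOffsets);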
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:23
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
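A worked sketch, with ctx an assumed MLIRContext*:

  // The map (d0, d1, d2) -> (d2, d0, d1) makes result i select
  // source[getDimPosition(i)], so [10, 20, 30] becomes [30, 10, 20].
  AffineMap map =
      AffineMap::getPermutationMap(ArrayRef<unsigned>{2, 0, 1}, ctx);
  SmallVector<int64_t> permuted =
      applyPermutationMap<int64_t>(map, ArrayRef<int64_t>{10, 20, 30});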
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn
Callback function type for setting the strided metadata of a value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries erased, return the set of indices that specifies which of the entries of originalShape are dropped to obtain reducedShape.
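For example, with values chosen for illustration:

  // Which unit dims were dropped going from [1, 4, 1, 8] to [4, 8]?
  // Returns the mask {0, 2}; std::nullopt means no consistent assignment of
  // dropped unit dims exists.
  std::optional<llvm::SmallDenseSet<unsigned>> mask =
      computeRankReductionMask(/*originalShape=*/{1, 4, 1, 8},
                               /*reducedShape=*/{4, 8});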
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType by dropping some dimensions with static size 1.
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
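A sketch of the round trip with getMixedValues, where mixed is an assumed SmallVector<OpFoldResult>:

  // Attributes become int64_t entries; Values become ShapedType::kDynamic
  // entries plus a parallel list of the SSA values themselves.
  auto [staticVals, dynamicVals] = decomposeMixedValues(mixed);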
function_ref< IntegerValueRange(Value)> GetIntRangeFn
Helper callback type to get the integer range of a value.
Move allocations into an allocation scope, if it is legal to move them (e.g., when their memory is guaranteed to be released at the end of the scope).
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocations.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or static.
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.