MemRefOps.cpp
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/AffineMap.h"
14#include "mlir/IR/Builders.h"
16#include "mlir/IR/Matchers.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ADT/SmallBitVector.h"
26
27using namespace mlir;
28using namespace mlir::memref;
29
30/// Materialize a single constant operation from a given attribute value with
31/// the desired resultant type.
32Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
33 Attribute value, Type type,
34 Location loc) {
35 return arith::ConstantOp::materialize(builder, value, type, loc);
36}
37
38//===----------------------------------------------------------------------===//
39// Common canonicalization pattern support logic
40//===----------------------------------------------------------------------===//
41
42/// This is a common utility used for patterns of the form
43/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
44/// into the root operation directly.
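/// For example (an illustrative use via `memref.dealloc`, whose folder calls
/// this helper):
///
/// ```mlir
/// %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
/// memref.dealloc %0 : memref<?xf32>
/// ```
///
/// folds to:
///
/// ```mlir
/// memref.dealloc %arg : memref<8xf32>
/// ```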
45LogicalResult mlir::memref::foldMemRefCast(Operation *op, Value inner) {
46 bool folded = false;
47 for (OpOperand &operand : op->getOpOperands()) {
48 auto cast = operand.get().getDefiningOp<CastOp>();
49 if (cast && operand.get() != inner &&
50 !llvm::isa<UnrankedMemRefType>(cast.getOperand().getType())) {
51 operand.set(cast.getOperand());
52 folded = true;
53 }
54 }
55 return success(folded);
56}
57
58/// Return an unranked/ranked tensor type for the given unranked/ranked memref
59/// type.
60Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
61 if (auto memref = llvm::dyn_cast<MemRefType>(type))
62 return RankedTensorType::get(memref.getShape(), memref.getElementType());
63 if (auto memref = llvm::dyn_cast<UnrankedMemRefType>(type))
64 return UnrankedTensorType::get(memref.getElementType());
65 return NoneType::get(type.getContext());
66}
67
68OpFoldResult memref::getMixedSize(OpBuilder &builder, Location loc, Value value,
69 int64_t dim) {
70 auto memrefType = llvm::cast<MemRefType>(value.getType());
71 if (memrefType.isDynamicDim(dim))
72 return builder.createOrFold<memref::DimOp>(loc, value, dim);
73
74 return builder.getIndexAttr(memrefType.getDimSize(dim));
75}
76
77SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder,
78 Location loc, Value value) {
79 auto memrefType = llvm::cast<MemRefType>(value.getType());
80 SmallVector<OpFoldResult> result;
81 for (int64_t i = 0; i < memrefType.getRank(); ++i)
82 result.push_back(getMixedSize(builder, loc, value, i));
83 return result;
84}
85
86//===----------------------------------------------------------------------===//
87// Utility functions for propagating static information
88//===----------------------------------------------------------------------===//
89
90/// Helper function that sets values[i] to constValues[i] if the latter is a
91/// static value, i.e., not ShapedType::kDynamic.
92///
93/// If constValues[i] is dynamic, tries to extract a constant value from
94/// values[i] to allow for additional folding opportunities. Also converts all
95/// existing attributes to index attributes. (They may be i64 attributes.)
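/// For example (illustrative values): with constValues = [8, kDynamic] and
/// values = [%v, %c4], where %c4 is produced by `arith.constant 4 : index`,
/// the helper rewrites values to the index attributes [8, 4].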
96static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values,
97 ArrayRef<int64_t> constValues) {
98 assert(constValues.size() == values.size() &&
99 "incorrect number of const values");
100 for (auto [i, cstVal] : llvm::enumerate(constValues)) {
101 Builder builder(values[i].getContext());
102 if (ShapedType::isStatic(cstVal)) {
103 // Constant value is known, use it directly.
104 values[i] = builder.getIndexAttr(cstVal);
105 continue;
106 }
107 if (std::optional<int64_t> cst = getConstantIntValue(values[i])) {
108 // Try to extract a constant or convert an existing attribute to an index attribute.
109 values[i] = builder.getIndexAttr(*cst);
110 }
111 }
112}
113
114/// Helper function to retrieve a lossless memory-space cast, and the
115/// corresponding new result memref type.
116static std::tuple<MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type>
117getMemorySpaceCastInfo(PtrLikeTypeInterface resultTy, Value src) {
118 MemorySpaceCastOpInterface castOp =
119 MemorySpaceCastOpInterface::getIfPromotableCast(src);
120
121 // Bail if the cast is not lossless.
122 if (!castOp)
123 return {};
124
125 // Transform the source and target type of `castOp` to have the same metadata
126 // as `resultTy`. Bail if not possible.
127 FailureOr<PtrLikeTypeInterface> srcTy = resultTy.clonePtrWith(
128 castOp.getSourcePtr().getType().getMemorySpace(), std::nullopt);
129 if (failed(srcTy))
130 return {};
131
132 FailureOr<PtrLikeTypeInterface> tgtTy = resultTy.clonePtrWith(
133 castOp.getTargetPtr().getType().getMemorySpace(), std::nullopt);
134 if (failed(tgtTy))
135 return {};
136
137 // Check if this is a valid memory-space cast.
138 if (!castOp.isValidMemorySpaceCast(*tgtTy, *srcTy))
139 return {};
140
141 return std::make_tuple(castOp, *tgtTy, *srcTy);
142}
143
144/// Implementation of `bubbleDownCasts` method for memref operations that
145/// return a single memref result.
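/// For example (an illustrative sketch using `memref.assume_alignment`, which
/// forwards to this helper; whether the rewrite applies depends on the
/// memory-space cast being promotable):
///
/// ```mlir
/// %0 = memref.memory_space_cast %src : memref<8xf32, 1> to memref<8xf32>
/// %1 = memref.assume_alignment %0, 64 : memref<8xf32>
/// ```
///
/// may become:
///
/// ```mlir
/// %1 = memref.assume_alignment %src, 64 : memref<8xf32, 1>
/// %2 = memref.memory_space_cast %1 : memref<8xf32, 1> to memref<8xf32>
/// ```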
146template <typename ConcreteOpTy>
147static FailureOr<std::optional<SmallVector<Value>>>
148bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder,
149 OpOperand &src) {
150 auto [castOp, tgtTy, resTy] = getMemorySpaceCastInfo(op.getType(), src.get());
151 // Bail if we cannot cast.
152 if (!castOp)
153 return failure();
154
155 // Create the new operands.
156 SmallVector<Value> operands;
157 llvm::append_range(operands, op->getOperands());
158 operands[src.getOperandNumber()] = castOp.getSourcePtr();
159
160 // Create the new op and results.
161 auto newOp = ConcreteOpTy::create(
162 builder, op.getLoc(), TypeRange(resTy), operands, op.getProperties(),
163 llvm::to_vector_of<NamedAttribute>(op->getDiscardableAttrs()));
164
165 // Insert a memory-space cast to the original memory space of the op.
166 MemorySpaceCastOpInterface result = castOp.cloneMemorySpaceCastOp(
167 builder, tgtTy,
168 cast<TypedValue<PtrLikeTypeInterface>>(newOp.getResult()));
169 return std::optional<SmallVector<Value>>(
170 SmallVector<Value>({result.getTargetPtr()}));
171}
172
173//===----------------------------------------------------------------------===//
174// AllocOp / AllocaOp
175//===----------------------------------------------------------------------===//
176
177void AllocOp::getAsmResultNames(
178 function_ref<void(Value, StringRef)> setNameFn) {
179 setNameFn(getResult(), "alloc");
180}
181
182void AllocaOp::getAsmResultNames(
183 function_ref<void(Value, StringRef)> setNameFn) {
184 setNameFn(getResult(), "alloca");
185}
186
187template <typename AllocLikeOp>
188static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
189 static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
190 "applies to only alloc or alloca");
191 auto memRefType = llvm::dyn_cast<MemRefType>(op.getResult().getType());
192 if (!memRefType)
193 return op.emitOpError("result must be a memref");
194
195 if (op.getDynamicSizes().size() != memRefType.getNumDynamicDims())
196 return op.emitOpError("dimension operand count does not equal memref "
197 "dynamic dimension count");
198
199 unsigned numSymbols = 0;
200 if (!memRefType.getLayout().isIdentity())
201 numSymbols = memRefType.getLayout().getAffineMap().getNumSymbols();
202 if (op.getSymbolOperands().size() != numSymbols)
203 return op.emitOpError("symbol operand count does not equal memref symbol "
204 "count: expected ")
205 << numSymbols << ", got " << op.getSymbolOperands().size();
206
207 return success();
208}
209
210LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
211
212LogicalResult AllocaOp::verify() {
213 // An alloca op needs to have an ancestor with an allocation scope trait.
214 if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
215 return emitOpError(
216 "requires an ancestor op with AutomaticAllocationScope trait");
217
218 return verifyAllocLikeOp(*this);
219}
220
221namespace {
222/// Fold constant dimensions into an alloc like operation.
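/// For example (illustrative IR):
///
/// ```mlir
/// %c16 = arith.constant 16 : index
/// %0 = memref.alloc(%c16) : memref<?xf32>
/// ```
///
/// is rewritten into a static allocation plus a cast back to the original type:
///
/// ```mlir
/// %1 = memref.alloc() : memref<16xf32>
/// %0 = memref.cast %1 : memref<16xf32> to memref<?xf32>
/// ```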
223template <typename AllocLikeOp>
224struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
225 using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
226
227 LogicalResult matchAndRewrite(AllocLikeOp alloc,
228 PatternRewriter &rewriter) const override {
229 // Check to see if any dimension operands are constants. If so, we can
230 // substitute and drop them.
231 if (llvm::none_of(alloc.getDynamicSizes(), [](Value operand) {
232 APInt constSizeArg;
233 if (!matchPattern(operand, m_ConstantInt(&constSizeArg)))
234 return false;
235 return constSizeArg.isNonNegative();
236 }))
237 return failure();
238
239 auto memrefType = alloc.getType();
240
241 // Ok, we have one or more constant operands. Collect the non-constant ones
242 // and keep track of the resultant memref type to build.
243 SmallVector<int64_t, 4> newShapeConstants;
244 newShapeConstants.reserve(memrefType.getRank());
245 SmallVector<Value, 4> dynamicSizes;
246
247 unsigned dynamicDimPos = 0;
248 for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
249 int64_t dimSize = memrefType.getDimSize(dim);
250 // If this is already a static dimension, keep it.
251 if (ShapedType::isStatic(dimSize)) {
252 newShapeConstants.push_back(dimSize);
253 continue;
254 }
255 auto dynamicSize = alloc.getDynamicSizes()[dynamicDimPos];
256 APInt constSizeArg;
257 if (matchPattern(dynamicSize, m_ConstantInt(&constSizeArg)) &&
258 constSizeArg.isNonNegative()) {
259 // Dynamic shape dimension will be folded.
260 newShapeConstants.push_back(constSizeArg.getZExtValue());
261 } else {
262 // Dynamic shape dimension not folded; copy dynamicSize from old memref.
263 newShapeConstants.push_back(ShapedType::kDynamic);
264 dynamicSizes.push_back(dynamicSize);
265 }
266 dynamicDimPos++;
267 }
268
269 // Create new memref type (which will have fewer dynamic dimensions).
270 MemRefType newMemRefType =
271 MemRefType::Builder(memrefType).setShape(newShapeConstants);
272 assert(dynamicSizes.size() == newMemRefType.getNumDynamicDims());
273
274 // Create and insert the alloc op for the new memref.
275 auto newAlloc = AllocLikeOp::create(rewriter, alloc.getLoc(), newMemRefType,
276 dynamicSizes, alloc.getSymbolOperands(),
277 alloc.getAlignmentAttr());
278 // Insert a cast so we have the same type as the old alloc.
279 rewriter.replaceOpWithNewOp<CastOp>(alloc, alloc.getType(), newAlloc);
280 return success();
281 }
282};
283
284/// Fold alloc operations with no users or only store and dealloc uses.
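/// For example (illustrative IR), the allocation below and both of its uses
/// are erased, since the value is stored into but never read:
///
/// ```mlir
/// %0 = memref.alloc() : memref<4xf32>
/// memref.store %v, %0[%i] : memref<4xf32>
/// memref.dealloc %0 : memref<4xf32>
/// ```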
285template <typename T>
286struct SimplifyDeadAlloc : public OpRewritePattern<T> {
287 using OpRewritePattern<T>::OpRewritePattern;
288
289 LogicalResult matchAndRewrite(T alloc,
290 PatternRewriter &rewriter) const override {
291 if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
292 if (auto storeOp = dyn_cast<StoreOp>(op))
293 return storeOp.getValue() == alloc;
294 return !isa<DeallocOp>(op);
295 }))
296 return failure();
297
298 for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
299 rewriter.eraseOp(user);
300
301 rewriter.eraseOp(alloc);
302 return success();
303 }
304};
305} // namespace
306
307void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
308 MLIRContext *context) {
309 results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
310}
311
312void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
313 MLIRContext *context) {
314 results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
315 context);
316}
317
318//===----------------------------------------------------------------------===//
319// ReallocOp
320//===----------------------------------------------------------------------===//
321
322LogicalResult ReallocOp::verify() {
323 auto sourceType = llvm::cast<MemRefType>(getOperand(0).getType());
324 MemRefType resultType = getType();
325
326 // The source memref should have identity layout (or none).
327 if (!sourceType.getLayout().isIdentity())
328 return emitError("unsupported layout for source memref type ")
329 << sourceType;
330
331 // The result memref should have identity layout (or none).
332 if (!resultType.getLayout().isIdentity())
333 return emitError("unsupported layout for result memref type ")
334 << resultType;
335
336 // The source memref and the result memref should be in the same memory space.
337 if (sourceType.getMemorySpace() != resultType.getMemorySpace())
338 return emitError("different memory spaces specified for source memref "
339 "type ")
340 << sourceType << " and result memref type " << resultType;
341
342 // The source memref and the result memref should have the same element type.
343 if (sourceType.getElementType() != resultType.getElementType())
344 return emitError("different element types specified for source memref "
345 "type ")
346 << sourceType << " and result memref type " << resultType;
347
348 // Verify that we have the dynamic dimension operand when it is needed.
349 if (resultType.getNumDynamicDims() && !getDynamicResultSize())
350 return emitError("missing dimension operand for result type ")
351 << resultType;
352 if (!resultType.getNumDynamicDims() && getDynamicResultSize())
353 return emitError("unnecessary dimension operand for result type ")
354 << resultType;
355
356 return success();
357}
358
359void ReallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
360 MLIRContext *context) {
361 results.add<SimplifyDeadAlloc<ReallocOp>>(context);
362}
363
364//===----------------------------------------------------------------------===//
365// AllocaScopeOp
366//===----------------------------------------------------------------------===//
367
368void AllocaScopeOp::print(OpAsmPrinter &p) {
369 bool printBlockTerminators = false;
370
371 p << ' ';
372 if (!getResults().empty()) {
373 p << " -> (" << getResultTypes() << ")";
374 printBlockTerminators = true;
375 }
376 p << ' ';
377 p.printRegion(getBodyRegion(),
378 /*printEntryBlockArgs=*/false,
379 /*printBlockTerminators=*/printBlockTerminators);
380 p.printOptionalAttrDict((*this)->getAttrs());
381}
382
383ParseResult AllocaScopeOp::parse(OpAsmParser &parser, OperationState &result) {
384 // Create a region for the body.
385 result.regions.reserve(1);
386 Region *bodyRegion = result.addRegion();
387
388 // Parse optional results type list.
389 if (parser.parseOptionalArrowTypeList(result.types))
390 return failure();
391
392 // Parse the body region.
393 if (parser.parseRegion(*bodyRegion, /*arguments=*/{}))
394 return failure();
395 AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
396 result.location);
397
398 // Parse the optional attribute list.
399 if (parser.parseOptionalAttrDict(result.attributes))
400 return failure();
401
402 return success();
403}
404
405void AllocaScopeOp::getSuccessorRegions(
406 RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
407 if (!point.isParent()) {
408 regions.push_back(RegionSuccessor(getOperation(), getResults()));
409 return;
410 }
411
412 regions.push_back(RegionSuccessor(&getBodyRegion()));
413}
414
415/// Given an operation, return whether this op is guaranteed to
416/// allocate an AutomaticAllocationScopeResource
417static bool isGuaranteedAutomaticAllocation(Operation *op) {
418 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
419 if (!interface)
420 return false;
421 for (auto res : op->getResults()) {
422 if (auto effect =
423 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
424 if (isa<SideEffects::AutomaticAllocationScopeResource>(
425 effect->getResource()))
426 return true;
427 }
428 }
429 return false;
430}
431
432/// Given an operation, return whether this op itself could
433/// allocate an AutomaticAllocationScopeResource. Note that
434/// this will not check whether an operation contained within
435/// the op can allocate.
436static bool isOpItselfPotentialAutomaticAllocation(Operation *op) {
437 // This op itself doesn't create a stack allocation,
438 // the inner allocation should be handled separately.
439 if (op->hasTrait<OpTrait::HasRecursiveMemoryEffects>())
440 return false;
441 MemoryEffectOpInterface interface = dyn_cast<MemoryEffectOpInterface>(op);
442 if (!interface)
443 return true;
444 for (auto res : op->getResults()) {
445 if (auto effect =
446 interface.getEffectOnValue<MemoryEffects::Allocate>(res)) {
447 if (isa<SideEffects::AutomaticAllocationScopeResource>(
448 effect->getResource()))
449 return true;
450 }
451 }
452 return false;
453}
454
455/// Return whether this op is the last non-terminating op
456/// in a region. That is to say, it is in a one-block region
457/// and is only followed by a terminator. This prevents
458/// extending the lifetime of allocations.
459static bool lastNonTerminatorInRegion(Operation *op) {
460 return op->getBlock()->mightHaveTerminator() &&
461 op->getNextNode() == op->getBlock()->getTerminator() &&
462 op->getBlock()->getParent()->hasOneBlock();
463}
464
465/// Inline an AllocaScopeOp if either the direct parent is an allocation scope
466/// or it contains no allocation.
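/// For example (an illustrative sketch; `test.use` stands in for an arbitrary
/// user), the scope below sits directly in a `func.func`, which itself has the
/// AutomaticAllocationScope trait, and is the last non-terminator op, so its
/// body can be inlined:
///
/// ```mlir
/// func.func @f() {
///   memref.alloca_scope {
///     %0 = memref.alloca() : memref<4xf32>
///     "test.use"(%0) : (memref<4xf32>) -> ()
///   }
///   return
/// }
/// ```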
467struct AllocaScopeInliner : public OpRewritePattern<AllocaScopeOp> {
468 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
469
470 LogicalResult matchAndRewrite(AllocaScopeOp op,
471 PatternRewriter &rewriter) const override {
472 bool hasPotentialAlloca =
473 op->walk<WalkOrder::PreOrder>([&](Operation *alloc) {
474 if (alloc == op)
475 return WalkResult::advance();
476 if (isOpItselfPotentialAutomaticAllocation(alloc))
477 return WalkResult::interrupt();
478 if (alloc->hasTrait<OpTrait::AutomaticAllocationScope>())
479 return WalkResult::skip();
480 return WalkResult::advance();
481 }).wasInterrupted();
482
483 // If this contains no potential allocation, it is always legal to
484 // inline. Otherwise, consider two conditions:
485 if (hasPotentialAlloca) {
486 // If the parent isn't an allocation scope, or we are not the last
487 // non-terminator op in the parent, we will extend the lifetime.
488 if (!op->getParentOp()->hasTrait<OpTrait::AutomaticAllocationScope>())
489 return failure();
490 if (!lastNonTerminatorInRegion(op))
491 return failure();
492 }
493
494 Block *block = &op.getRegion().front();
495 Operation *terminator = block->getTerminator();
496 ValueRange results = terminator->getOperands();
497 rewriter.inlineBlockBefore(block, op);
498 rewriter.replaceOp(op, results);
499 rewriter.eraseOp(terminator);
500 return success();
501 }
502};
503
504/// Move allocations into an allocation scope, if it is legal to
505/// move them (e.g. their operands are available at the location
506/// the op would be moved to).
507struct AllocaScopeHoister : public OpRewritePattern<AllocaScopeOp> {
508 using OpRewritePattern<AllocaScopeOp>::OpRewritePattern;
509
510 LogicalResult matchAndRewrite(AllocaScopeOp op,
511 PatternRewriter &rewriter) const override {
512
513 if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
514 return failure();
515
516 Operation *lastParentWithoutScope = op->getParentOp();
517
518 if (!lastParentWithoutScope ||
519 lastParentWithoutScope->hasTrait<OpTrait::AutomaticAllocationScope>())
520 return failure();
521
522 // Only apply if this is the last non-terminator
523 // op in the block (lest the lifetime be extended) of a one-block
524 // region.
525 if (!lastNonTerminatorInRegion(op) ||
526 !lastNonTerminatorInRegion(lastParentWithoutScope))
527 return failure();
528
529 while (!lastParentWithoutScope->getParentOp()
530 ->hasTrait<OpTrait::AutomaticAllocationScope>()) {
531 lastParentWithoutScope = lastParentWithoutScope->getParentOp();
532 if (!lastParentWithoutScope ||
533 !lastNonTerminatorInRegion(lastParentWithoutScope))
534 return failure();
535 }
536 assert(lastParentWithoutScope->getParentOp()
537 ->hasTrait<OpTrait::AutomaticAllocationScope>());
538
539 Region *containingRegion = nullptr;
540 for (auto &r : lastParentWithoutScope->getRegions()) {
541 if (r.isAncestor(op->getParentRegion())) {
542 assert(containingRegion == nullptr &&
543 "only one region can contain the op");
544 containingRegion = &r;
545 }
546 }
547 assert(containingRegion && "op must be contained in a region");
548
550 op->walk([&](Operation *alloc) {
551 if (!isGuaranteedAutomaticAllocation(alloc))
552 return WalkResult::skip();
553
554 // If any operand is not defined before the location of
555 // lastParentWithoutScope (i.e. where we would hoist to), skip.
556 if (llvm::any_of(alloc->getOperands(), [&](Value v) {
557 return containingRegion->isAncestor(v.getParentRegion());
558 }))
559 return WalkResult::skip();
560 toHoist.push_back(alloc);
561 return WalkResult::advance();
562 });
563
564 if (toHoist.empty())
565 return failure();
566 rewriter.setInsertionPoint(lastParentWithoutScope);
567 for (auto *op : toHoist) {
568 auto *cloned = rewriter.clone(*op);
569 rewriter.replaceOp(op, cloned->getResults());
570 }
571 return success();
572 }
573};
574
575void AllocaScopeOp::getCanonicalizationPatterns(RewritePatternSet &results,
576 MLIRContext *context) {
577 results.add<AllocaScopeInliner, AllocaScopeHoister>(context);
578}
579
580//===----------------------------------------------------------------------===//
581// AssumeAlignmentOp
582//===----------------------------------------------------------------------===//
583
584LogicalResult AssumeAlignmentOp::verify() {
585 if (!llvm::isPowerOf2_32(getAlignment()))
586 return emitOpError("alignment must be power of 2");
587 return success();
588}
589
590void AssumeAlignmentOp::getAsmResultNames(
591 function_ref<void(Value, StringRef)> setNameFn) {
592 setNameFn(getResult(), "assume_align");
593}
594
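/// Fold an assume_alignment whose source is another assume_alignment with the
/// same alignment: the redundant outer op is replaced by its memref operand.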
595OpFoldResult AssumeAlignmentOp::fold(FoldAdaptor adaptor) {
596 auto source = getMemref().getDefiningOp<AssumeAlignmentOp>();
597 if (!source)
598 return {};
599 if (source.getAlignment() != getAlignment())
600 return {};
601 return getMemref();
602}
603
604FailureOr<std::optional<SmallVector<Value>>>
605AssumeAlignmentOp::bubbleDownCasts(OpBuilder &builder) {
606 return bubbleDownCastsPassthroughOpImpl(*this, builder, getMemrefMutable());
607}
608
609//===----------------------------------------------------------------------===//
610// DistinctObjectsOp
611//===----------------------------------------------------------------------===//
612
613LogicalResult DistinctObjectsOp::verify() {
614 if (getOperandTypes() != getResultTypes())
615 return emitOpError("operand types and result types must match");
616
617 if (getOperandTypes().empty())
618 return emitOpError("expected at least one operand");
619
620 return success();
621}
622
623LogicalResult DistinctObjectsOp::inferReturnTypes(
624 MLIRContext * /*context*/, std::optional<Location> /*location*/,
625 ValueRange operands, DictionaryAttr /*attributes*/,
626 OpaqueProperties /*properties*/, RegionRange /*regions*/,
627 SmallVectorImpl<Type> &inferredReturnTypes) {
628 llvm::copy(operands.getTypes(), std::back_inserter(inferredReturnTypes));
629 return success();
630}
631
632//===----------------------------------------------------------------------===//
633// CastOp
634//===----------------------------------------------------------------------===//
635
636void CastOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
637 setNameFn(getResult(), "cast");
638}
639
640/// Determines whether MemRef_CastOp casts to a more dynamic version of the
641/// source memref. This is useful to fold a memref.cast into a consuming op
642/// and implement canonicalization patterns for ops in different dialects that
643/// may consume the results of memref.cast operations. Such foldable memref.cast
644/// operations are typically inserted as `view` and `subview` ops are
645/// canonicalized, to preserve the type compatibility of their uses.
646///
647/// Returns true when all conditions are met:
648/// 1. source and result are ranked memrefs with strided semantics and same
649/// element type and rank.
650/// 2. each of the source's size, offset or stride has more static information
651/// than the corresponding result's size, offset or stride.
652///
653/// Example 1:
654/// ```mlir
655/// %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
656/// %2 = consumer %1 ... : memref<?x?xf32> ...
657/// ```
658///
659/// may fold into:
660///
661/// ```mlir
662/// %2 = consumer %0 ... : memref<8x16xf32> ...
663/// ```
664///
665/// Example 2:
666/// ```
667/// %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
668/// to memref<?x?xf32>
669/// consumer %1 : memref<?x?xf32> ...
670/// ```
671///
672/// may fold into:
673///
674/// ```
675/// consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
676/// ```
677bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
678 MemRefType sourceType =
679 llvm::dyn_cast<MemRefType>(castOp.getSource().getType());
680 MemRefType resultType = llvm::dyn_cast<MemRefType>(castOp.getType());
681
682 // Requires ranked MemRefType.
683 if (!sourceType || !resultType)
684 return false;
685
686 // Requires same elemental type.
687 if (sourceType.getElementType() != resultType.getElementType())
688 return false;
689
690 // Requires same rank.
691 if (sourceType.getRank() != resultType.getRank())
692 return false;
693
694 // Only fold casts between strided memref forms.
695 int64_t sourceOffset, resultOffset;
696 SmallVector<int64_t, 4> sourceStrides, resultStrides;
697 if (failed(sourceType.getStridesAndOffset(sourceStrides, sourceOffset)) ||
698 failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
699 return false;
700
701 // If cast is towards more static sizes along any dimension, don't fold.
702 for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
703 auto ss = std::get<0>(it), st = std::get<1>(it);
704 if (ss != st)
705 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
706 return false;
707 }
708
709 // If cast is towards more static offset along any dimension, don't fold.
710 if (sourceOffset != resultOffset)
711 if (ShapedType::isDynamic(sourceOffset) &&
712 ShapedType::isStatic(resultOffset))
713 return false;
714
715 // If cast is towards more static strides along any dimension, don't fold.
716 for (auto it : llvm::zip(sourceStrides, resultStrides)) {
717 auto ss = std::get<0>(it), st = std::get<1>(it);
718 if (ss != st)
719 if (ShapedType::isDynamic(ss) && ShapedType::isStatic(st))
720 return false;
721 }
722
723 return true;
724}
725
726bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
727 if (inputs.size() != 1 || outputs.size() != 1)
728 return false;
729 Type a = inputs.front(), b = outputs.front();
730 auto aT = llvm::dyn_cast<MemRefType>(a);
731 auto bT = llvm::dyn_cast<MemRefType>(b);
732
733 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
734 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
735
736 if (aT && bT) {
737 if (aT.getElementType() != bT.getElementType())
738 return false;
739 if (aT.getLayout() != bT.getLayout()) {
740 int64_t aOffset, bOffset;
741 SmallVector<int64_t, 4> aStrides, bStrides;
742 if (failed(aT.getStridesAndOffset(aStrides, aOffset)) ||
743 failed(bT.getStridesAndOffset(bStrides, bOffset)) ||
744 aStrides.size() != bStrides.size())
745 return false;
746
747 // Strides along a dimension/offset are compatible if the value in the
748 // source memref is static and the value in the target memref is the
749 // same. They are also compatible if either one is dynamic (see
750 // description of MemRefCastOp for details).
751 auto checkCompatible = [](int64_t a, int64_t b) {
752 return (ShapedType::isDynamic(a) || ShapedType::isDynamic(b) || a == b);
753 };
754 if (!checkCompatible(aOffset, bOffset))
755 return false;
756 for (const auto &aStride : enumerate(aStrides))
757 if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
758 return false;
759 }
760 if (aT.getMemorySpace() != bT.getMemorySpace())
761 return false;
762
763 // They must have the same rank, and any specified dimensions must match.
764 if (aT.getRank() != bT.getRank())
765 return false;
766
767 for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
768 int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
769 if (ShapedType::isStatic(aDim) && ShapedType::isStatic(bDim) &&
770 aDim != bDim)
771 return false;
772 }
773 return true;
774 } else {
775 if (!aT && !uaT)
776 return false;
777 if (!bT && !ubT)
778 return false;
779 // Unranked to unranked casting is unsupported
780 if (uaT && ubT)
781 return false;
782
783 auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
784 auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
785 if (aEltType != bEltType)
786 return false;
787
788 auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
789 auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
790 return aMemSpace == bMemSpace;
791 }
792
793 return false;
794}
795
796OpFoldResult CastOp::fold(FoldAdaptor adaptor) {
797 return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
798}
799
800FailureOr<std::optional<SmallVector<Value>>>
801CastOp::bubbleDownCasts(OpBuilder &builder) {
802 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
803}
804
805//===----------------------------------------------------------------------===//
806// CopyOp
807//===----------------------------------------------------------------------===//
808
809namespace {
810
811/// Fold memref.copy(%x, %x).
812struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
813 using OpRewritePattern<CopyOp>::OpRewritePattern;
814
815 LogicalResult matchAndRewrite(CopyOp copyOp,
816 PatternRewriter &rewriter) const override {
817 if (copyOp.getSource() != copyOp.getTarget())
818 return failure();
819
820 rewriter.eraseOp(copyOp);
821 return success();
822 }
823};
824
825struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
826 using OpRewritePattern<CopyOp>::OpRewritePattern;
827
828 static bool isEmptyMemRef(BaseMemRefType type) {
829 return type.hasRank() && llvm::is_contained(type.getShape(), 0);
830 }
831
832 LogicalResult matchAndRewrite(CopyOp copyOp,
833 PatternRewriter &rewriter) const override {
834 if (isEmptyMemRef(copyOp.getSource().getType()) ||
835 isEmptyMemRef(copyOp.getTarget().getType())) {
836 rewriter.eraseOp(copyOp);
837 return success();
838 }
839
840 return failure();
841 }
842};
843} // namespace
844
845void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
846 MLIRContext *context) {
847 results.add<FoldEmptyCopy, FoldSelfCopy>(context);
848}
849
850/// If the source/target of a CopyOp is a CastOp that does not modify the shape
851/// and element type, the cast can be skipped. Such CastOps only cast the layout
852/// of the type.
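/// For example (illustrative IR), a cast that only relaxes the layout and
/// static strides can be folded away:
///
/// ```mlir
/// %0 = memref.cast %src
///     : memref<8x16xf32> to memref<8x16xf32, strided<[?, ?], offset: ?>>
/// memref.copy %0, %dst
///     : memref<8x16xf32, strided<[?, ?], offset: ?>> to memref<8x16xf32>
/// ```
///
/// becomes:
///
/// ```mlir
/// memref.copy %src, %dst : memref<8x16xf32> to memref<8x16xf32>
/// ```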
853static LogicalResult FoldCopyOfCast(CopyOp op) {
854 for (OpOperand &operand : op->getOpOperands()) {
855 auto castOp = operand.get().getDefiningOp<memref::CastOp>();
856 if (castOp && memref::CastOp::canFoldIntoConsumerOp(castOp)) {
857 operand.set(castOp.getOperand());
858 return success();
859 }
860 }
861 return failure();
862}
863
864LogicalResult CopyOp::fold(FoldAdaptor adaptor,
865 SmallVectorImpl<OpFoldResult> &results) {
866
867 /// copy(memrefcast) -> copy
868 return FoldCopyOfCast(*this);
869}
870
871//===----------------------------------------------------------------------===//
872// DeallocOp
873//===----------------------------------------------------------------------===//
874
875LogicalResult DeallocOp::fold(FoldAdaptor adaptor,
876 SmallVectorImpl<OpFoldResult> &results) {
877 /// dealloc(memrefcast) -> dealloc
878 return foldMemRefCast(*this);
879}
880
881//===----------------------------------------------------------------------===//
882// DimOp
883//===----------------------------------------------------------------------===//
884
885void DimOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
886 setNameFn(getResult(), "dim");
887}
888
889void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
890 int64_t index) {
891 auto loc = result.location;
892 Value indexValue = arith::ConstantIndexOp::create(builder, loc, index);
893 build(builder, result, source, indexValue);
894}
895
896std::optional<int64_t> DimOp::getConstantIndex() {
897 return getConstantIntValue(getIndex());
898}
899
900Speculation::Speculatability DimOp::getSpeculatability() {
901 auto constantIndex = getConstantIndex();
902 if (!constantIndex)
903 return Speculation::NotSpeculatable;
904
905 auto rankedSourceType = dyn_cast<MemRefType>(getSource().getType());
906 if (!rankedSourceType)
907 return Speculation::NotSpeculatable;
908
909 if (rankedSourceType.getRank() <= constantIndex)
910 return Speculation::NotSpeculatable;
911
912 return Speculation::Speculatable;
913}
914
915void DimOp::inferResultRangesFromOptional(ArrayRef<IntegerValueRange> argRanges,
916 SetIntLatticeFn setResultRange) {
917 setResultRange(getResult(),
918 intrange::inferShapedDimOpInterface(*this, argRanges[1]));
919}
920
921/// Return a map whose keys are the elements in `vals` and whose values are
922/// the number of occurrences of each element. Use std::map, since the `vals`
923/// here are strides and the dynamic stride value is the same as the tombstone
924/// value for `DenseMap<int64_t>`.
925static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
926 std::map<int64_t, unsigned> numOccurences;
927 for (auto val : vals)
928 numOccurences[val]++;
929 return numOccurences;
930}
931
932/// Given the `originalType` and a `reducedType` whose shape is assumed to be
933/// a subset of `originalType` with some `1` entries erased, return the set of
934/// indices that specifies which of the entries of the original shape are
935/// dropped to obtain the reduced shape.
936/// This accounts for cases where there are multiple unit-dims, but only a
937/// subset of those are dropped. For MemRefTypes these can be disambiguated
938/// using the strides. If a dimension is dropped the stride must be dropped too.
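/// For example (illustrative types), a subview of `memref<8x16x4xf32>` with
/// mixed sizes `[1, 16, 4]` and result type `memref<16x4xf32>` yields the mask
/// `{0}`: the leading unit dimension is the one that was dropped.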
939static FailureOr<llvm::SmallBitVector>
940computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
941 ArrayRef<OpFoldResult> sizes) {
942 llvm::SmallBitVector unusedDims(originalType.getRank());
943 if (originalType.getRank() == reducedType.getRank())
944 return unusedDims;
945
946 for (const auto &dim : llvm::enumerate(sizes))
947 if (auto attr = llvm::dyn_cast_if_present<Attribute>(dim.value()))
948 if (llvm::cast<IntegerAttr>(attr).getInt() == 1)
949 unusedDims.set(dim.index());
950
951 // Early exit for the case where the number of unused dims matches the number
952 // of ranks reduced.
953 if (static_cast<int64_t>(unusedDims.count()) + reducedType.getRank() ==
954 originalType.getRank())
955 return unusedDims;
956
957 SmallVector<int64_t> originalStrides, candidateStrides;
958 int64_t originalOffset, candidateOffset;
959 if (failed(
960 originalType.getStridesAndOffset(originalStrides, originalOffset)) ||
961 failed(
962 reducedType.getStridesAndOffset(candidateStrides, candidateOffset)))
963 return failure();
964
965 // For memrefs, a dimension is truly dropped if its corresponding stride is
966 // also dropped. This is particularly important when more than one of the dims
967 // is 1. Track the number of occurrences of the strides in the original type
968 // and the candidate type. For each unused dim that stride should not be
969 // present in the candidate type. Note that there could be multiple dimensions
970 // that have the same size. We don't need to exactly figure out which dim
971 // corresponds to which stride; we just need to verify that the number of
972 // repetitions of a stride in the original type equals the number of repetitions
973 // in the candidate type plus the number of dropped dims with that stride.
974 std::map<int64_t, unsigned> currUnaccountedStrides =
975 getNumOccurences(originalStrides);
976 std::map<int64_t, unsigned> candidateStridesNumOccurences =
977 getNumOccurences(candidateStrides);
978 for (size_t dim = 0, e = unusedDims.size(); dim != e; ++dim) {
979 if (!unusedDims.test(dim))
980 continue;
981 int64_t originalStride = originalStrides[dim];
982 if (currUnaccountedStrides[originalStride] >
983 candidateStridesNumOccurences[originalStride]) {
984 // This dim can be treated as dropped.
985 currUnaccountedStrides[originalStride]--;
986 continue;
987 }
988 if (currUnaccountedStrides[originalStride] ==
989 candidateStridesNumOccurences[originalStride]) {
990 // The stride for this is not dropped. Keep as is.
991 unusedDims.reset(dim);
992 continue;
993 }
994 if (currUnaccountedStrides[originalStride] <
995 candidateStridesNumOccurences[originalStride]) {
996 // This should never happen. Can't have a stride in the reduced rank type
997 // that wasn't in the original one.
998 return failure();
999 }
1000 }
1001
1002 if ((int64_t)unusedDims.count() + reducedType.getRank() !=
1003 originalType.getRank())
1004 return failure();
1005 return unusedDims;
1006}
1007
1008llvm::SmallBitVector SubViewOp::getDroppedDims() {
1009 MemRefType sourceType = getSourceType();
1010 MemRefType resultType = getType();
1011 FailureOr<llvm::SmallBitVector> unusedDims =
1012 computeMemRefRankReductionMask(sourceType, resultType, getMixedSizes());
1013 assert(succeeded(unusedDims) && "unable to find unused dims of subview");
1014 return *unusedDims;
1015}
1016
1017OpFoldResult DimOp::fold(FoldAdaptor adaptor) {
1018 // All forms of folding require a known index.
1019 auto index = llvm::dyn_cast_if_present<IntegerAttr>(adaptor.getIndex());
1020 if (!index)
1021 return {};
1022
1023 // Folding for unranked types (UnrankedMemRefType) is not supported.
1024 auto memrefType = llvm::dyn_cast<MemRefType>(getSource().getType());
1025 if (!memrefType)
1026 return {};
1027
1028 // Out of bound indices produce undefined behavior but are still valid IR.
1029 // Don't choke on them.
1030 int64_t indexVal = index.getInt();
1031 if (indexVal < 0 || indexVal >= memrefType.getRank())
1032 return {};
1033
1034 // Fold if the shape extent along the given index is known.
1035 if (!memrefType.isDynamicDim(index.getInt())) {
1036 Builder builder(getContext());
1037 return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
1038 }
1039
1040 // The size at the given index is now known to be a dynamic size.
1041 unsigned unsignedIndex = index.getValue().getZExtValue();
1042
1043 // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
1044 Operation *definingOp = getSource().getDefiningOp();
1045
1046 if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
1047 return *(alloc.getDynamicSizes().begin() +
1048 memrefType.getDynamicDimIndex(unsignedIndex));
1049
1050 if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
1051 return *(alloca.getDynamicSizes().begin() +
1052 memrefType.getDynamicDimIndex(unsignedIndex));
1053
1054 if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
1055 return *(view.getDynamicSizes().begin() +
1056 memrefType.getDynamicDimIndex(unsignedIndex));
1057
1058 if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
1059 llvm::SmallBitVector unusedDims = subview.getDroppedDims();
1060 unsigned resultIndex = 0;
1061 unsigned sourceRank = subview.getSourceType().getRank();
1062 unsigned sourceIndex = 0;
1063 for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
1064 if (unusedDims.test(i))
1065 continue;
1066 if (resultIndex == unsignedIndex) {
1067 sourceIndex = i;
1068 break;
1069 }
1070 resultIndex++;
1071 }
1072 assert(subview.isDynamicSize(sourceIndex) &&
1073 "expected dynamic subview size");
1074 return subview.getDynamicSize(sourceIndex);
1075 }
1076
1077 // dim(memrefcast) -> dim
1078 if (succeeded(foldMemRefCast(*this)))
1079 return getResult();
1080
1081 return {};
1082}
1083
1084namespace {
1085/// Fold dim of a memref reshape operation to a load into the reshape's shape
1086/// operand.
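/// For example (illustrative IR, with an index-typed shape operand so no cast
/// is needed):
///
/// ```mlir
/// %r = memref.reshape %src(%shape)
///     : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
/// %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// is rewritten to a load from the shape operand:
///
/// ```mlir
/// %d = memref.load %shape[%c1] : memref<2xindex>
/// ```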
1087struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
1088 using OpRewritePattern<DimOp>::OpRewritePattern;
1089
1090 LogicalResult matchAndRewrite(DimOp dim,
1091 PatternRewriter &rewriter) const override {
1092 auto reshape = dim.getSource().getDefiningOp<ReshapeOp>();
1093
1094 if (!reshape)
1095 return rewriter.notifyMatchFailure(
1096 dim, "Dim op is not defined by a reshape op.");
1097
1098 // dim of a memref reshape can be folded if dim.getIndex() dominates the
1099 // reshape. Instead of using `DominanceInfo` (which is usually costly) we
1100 // cheaply check that either of the following conditions hold:
1101 // 1. dim.getIndex() is defined in the same block as reshape but before
1102 // reshape.
1103 // 2. dim.getIndex() is defined in a parent block of
1104 // reshape.
1105
1106 // Check condition 1
1107 if (dim.getIndex().getParentBlock() == reshape->getBlock()) {
1108 if (auto *definingOp = dim.getIndex().getDefiningOp()) {
1109 if (reshape->isBeforeInBlock(definingOp)) {
1110 return rewriter.notifyMatchFailure(
1111 dim,
1112 "dim.getIndex is not defined before reshape in the same block.");
1113 }
1114 } // else dim.getIndex is a block argument to reshape->getBlock and
1115 // dominates reshape
1116 } // Check condition 2
1117 else if (dim->getBlock() != reshape->getBlock() &&
1118 !dim.getIndex().getParentRegion()->isProperAncestor(
1119 reshape->getParentRegion())) {
1120 // If dim and reshape are in the same block but dim.getIndex() isn't, we
1121 // already know dim.getIndex() dominates reshape without calling
1122 // `isProperAncestor`
1123 return rewriter.notifyMatchFailure(
1124 dim, "dim.getIndex does not dominate reshape.");
1125 }
1126
1127 // Place the load directly after the reshape to ensure that the shape memref
1128 // was not mutated.
1129 rewriter.setInsertionPointAfter(reshape);
1130 Location loc = dim.getLoc();
1131 Value load =
1132 LoadOp::create(rewriter, loc, reshape.getShape(), dim.getIndex());
1133 if (load.getType() != dim.getType())
1134 load = arith::IndexCastOp::create(rewriter, loc, dim.getType(), load);
1135 rewriter.replaceOp(dim, load);
1136 return success();
1137 }
1138};
1139
1140} // namespace
1141
1142void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1143 MLIRContext *context) {
1144 results.add<DimOfMemRefReshape>(context);
1145}
1146
1147// ---------------------------------------------------------------------------
1148// DmaStartOp
1149// ---------------------------------------------------------------------------
1150
1151void DmaStartOp::build(OpBuilder &builder, OperationState &result,
1152 Value srcMemRef, ValueRange srcIndices, Value destMemRef,
1153 ValueRange destIndices, Value numElements,
1154 Value tagMemRef, ValueRange tagIndices, Value stride,
1155 Value elementsPerStride) {
1156 result.addOperands(srcMemRef);
1157 result.addOperands(srcIndices);
1158 result.addOperands(destMemRef);
1159 result.addOperands(destIndices);
1160 result.addOperands({numElements, tagMemRef});
1161 result.addOperands(tagIndices);
1162 if (stride)
1163 result.addOperands({stride, elementsPerStride});
1164}
1165
1166void DmaStartOp::print(OpAsmPrinter &p) {
1167 p << " " << getSrcMemRef() << '[' << getSrcIndices() << "], "
1168 << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements()
1169 << ", " << getTagMemRef() << '[' << getTagIndices() << ']';
1170 if (isStrided())
1171 p << ", " << getStride() << ", " << getNumElementsPerStride();
1172
1173 p.printOptionalAttrDict((*this)->getAttrs());
1174 p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType()
1175 << ", " << getTagMemRef().getType();
1176}
1177
1178// Parse DmaStartOp.
1179// Ex:
1180// %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
1181// %tag[%index], %stride, %num_elt_per_stride :
1182// : memref<3076 x f32, 0>,
1183// memref<1024 x f32, 2>,
1184// memref<1 x i32>
1185//
1186ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) {
1187 OpAsmParser::UnresolvedOperand srcMemRefInfo;
1188 SmallVector<OpAsmParser::UnresolvedOperand, 4> srcIndexInfos;
1189 OpAsmParser::UnresolvedOperand dstMemRefInfo;
1190 SmallVector<OpAsmParser::UnresolvedOperand, 4> dstIndexInfos;
1191 OpAsmParser::UnresolvedOperand numElementsInfo;
1192 OpAsmParser::UnresolvedOperand tagMemrefInfo;
1193 SmallVector<OpAsmParser::UnresolvedOperand, 4> tagIndexInfos;
1194 SmallVector<OpAsmParser::UnresolvedOperand, 2> strideInfo;
1195
1196 SmallVector<Type, 3> types;
1197 auto indexType = parser.getBuilder().getIndexType();
1198
1199 // Parse and resolve the following list of operands:
1200 // *) source memref followed by its indices (in square brackets).
1201 // *) destination memref followed by its indices (in square brackets).
1202 // *) dma size in KiB.
1203 if (parser.parseOperand(srcMemRefInfo) ||
1204 parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
1205 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1206 parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
1207 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1208 parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
1209 parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
1210 return failure();
1211
1212 // Parse optional stride and elements per stride.
1213 if (parser.parseTrailingOperandList(strideInfo))
1214 return failure();
1215
1216 bool isStrided = strideInfo.size() == 2;
1217 if (!strideInfo.empty() && !isStrided) {
1218 return parser.emitError(parser.getNameLoc(),
1219 "expected two stride related operands");
1220 }
1221
1222 if (parser.parseColonTypeList(types))
1223 return failure();
1224 if (types.size() != 3)
1225 return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
1226
1227 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1228 parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
1229 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1230 parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
1231 // size should be an index.
1232 parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
1233 parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1234 // tag indices should be index.
1235 parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1236 return failure();
1237
1238 if (isStrided) {
1239 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1240 return failure();
1241 }
1242
1243 return success();
1244}
1245
1246LogicalResult DmaStartOp::verify() {
1247 unsigned numOperands = getNumOperands();
1248
1249 // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1250 // the number of elements.
1251 if (numOperands < 4)
1252 return emitOpError("expected at least 4 operands");
1253
1254 // Check types of operands. The order of these calls is important: the later
1255 // calls rely on some type properties to compute the operand position.
1256 // 1. Source memref.
1257 if (!llvm::isa<MemRefType>(getSrcMemRef().getType()))
1258 return emitOpError("expected source to be of memref type");
1259 if (numOperands < getSrcMemRefRank() + 4)
1260 return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
1261 << " operands";
1262 if (!getSrcIndices().empty() &&
1263 !llvm::all_of(getSrcIndices().getTypes(),
1264 [](Type t) { return t.isIndex(); }))
1265 return emitOpError("expected source indices to be of index type");
1266
1267 // 2. Destination memref.
1268 if (!llvm::isa<MemRefType>(getDstMemRef().getType()))
1269 return emitOpError("expected destination to be of memref type");
1270 unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
1271 if (numOperands < numExpectedOperands)
1272 return emitOpError() << "expected at least " << numExpectedOperands
1273 << " operands";
1274 if (!getDstIndices().empty() &&
1275 !llvm::all_of(getDstIndices().getTypes(),
1276 [](Type t) { return t.isIndex(); }))
1277 return emitOpError("expected destination indices to be of index type");
1278
1279 // 3. Number of elements.
1280 if (!getNumElements().getType().isIndex())
1281 return emitOpError("expected num elements to be of index type");
1282
1283 // 4. Tag memref.
1284 if (!llvm::isa<MemRefType>(getTagMemRef().getType()))
1285 return emitOpError("expected tag to be of memref type");
1286 numExpectedOperands += getTagMemRefRank();
1287 if (numOperands < numExpectedOperands)
1288 return emitOpError() << "expected at least " << numExpectedOperands
1289 << " operands";
1290 if (!getTagIndices().empty() &&
1291 !llvm::all_of(getTagIndices().getTypes(),
1292 [](Type t) { return t.isIndex(); }))
1293 return emitOpError("expected tag indices to be of index type");
1294
1295 // Optional stride-related operands must be either both present or both
1296 // absent.
1297 if (numOperands != numExpectedOperands &&
1298 numOperands != numExpectedOperands + 2)
1299 return emitOpError("incorrect number of operands");
1300
1301 // 5. Strides.
1302 if (isStrided()) {
1303 if (!getStride().getType().isIndex() ||
1304 !getNumElementsPerStride().getType().isIndex())
1305 return emitOpError(
1306 "expected stride and num elements per stride to be of type index");
1307 }
1308
1309 return success();
1310}
1311
1312LogicalResult DmaStartOp::fold(FoldAdaptor adaptor,
1313 SmallVectorImpl<OpFoldResult> &results) {
1314 /// dma_start(memrefcast) -> dma_start
1315 return foldMemRefCast(*this);
1316}
1317
1318// ---------------------------------------------------------------------------
1319// DmaWaitOp
1320// ---------------------------------------------------------------------------
1321
1322LogicalResult DmaWaitOp::fold(FoldAdaptor adaptor,
1323 SmallVectorImpl<OpFoldResult> &results) {
1324 /// dma_wait(memrefcast) -> dma_wait
1325 return foldMemRefCast(*this);
1326}
1327
1328LogicalResult DmaWaitOp::verify() {
1329 // Check that the number of tag indices matches the tagMemRef rank.
1330 unsigned numTagIndices = getTagIndices().size();
1331 unsigned tagMemRefRank = getTagMemRefRank();
1332 if (numTagIndices != tagMemRefRank)
1333 return emitOpError() << "expected tagIndices to have the same number of "
1334 "elements as the tagMemRef rank, expected "
1335 << tagMemRefRank << ", but got " << numTagIndices;
1336 return success();
1337}
1338
1339//===----------------------------------------------------------------------===//
1340// ExtractAlignedPointerAsIndexOp
1341//===----------------------------------------------------------------------===//
1342
1343void ExtractAlignedPointerAsIndexOp::getAsmResultNames(
1344 function_ref<void(Value, StringRef)> setNameFn) {
1345 setNameFn(getResult(), "intptr");
1346}
1347
1348//===----------------------------------------------------------------------===//
1349// ExtractStridedMetadataOp
1350//===----------------------------------------------------------------------===//
1351
1352/// The number and type of the results are inferred from the
1353/// shape of the source.
1354LogicalResult ExtractStridedMetadataOp::inferReturnTypes(
1355 MLIRContext *context, std::optional<Location> location,
1356 ExtractStridedMetadataOp::Adaptor adaptor,
1357 SmallVectorImpl<Type> &inferredReturnTypes) {
1358 auto sourceType = llvm::dyn_cast<MemRefType>(adaptor.getSource().getType());
1359 if (!sourceType)
1360 return failure();
1361
1362 unsigned sourceRank = sourceType.getRank();
1363 IndexType indexType = IndexType::get(context);
1364 auto memrefType =
1365 MemRefType::get({}, sourceType.getElementType(),
1366 MemRefLayoutAttrInterface{}, sourceType.getMemorySpace());
1367 // Base.
1368 inferredReturnTypes.push_back(memrefType);
1369 // Offset.
1370 inferredReturnTypes.push_back(indexType);
1371 // Sizes and strides.
1372 for (unsigned i = 0; i < sourceRank * 2; ++i)
1373 inferredReturnTypes.push_back(indexType);
1374 return success();
1375}
1376
1377void ExtractStridedMetadataOp::getAsmResultNames(
1378 function_ref<void(Value, StringRef)> setNameFn) {
1379 setNameFn(getBaseBuffer(), "base_buffer");
1380 setNameFn(getOffset(), "offset");
1381 // For multi-result to work properly with pretty names and packed syntax `x:3`
1382 // we can only give a pretty name to the first value in the pack.
1383 if (!getSizes().empty()) {
1384 setNameFn(getSizes().front(), "sizes");
1385 setNameFn(getStrides().front(), "strides");
1386 }
1387}
1388
1389/// Helper function to perform the replacement of all constant uses of `values`
1390/// by a materialized constant extracted from `maybeConstants`.
1391/// `values` and `maybeConstants` are expected to have the same size.
1392template <typename Container>
1393static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc,
1394 Container values,
1395 ArrayRef<OpFoldResult> maybeConstants) {
1396 assert(values.size() == maybeConstants.size() &&
1397 " expected values and maybeConstants of the same size");
1398 bool atLeastOneReplacement = false;
1399 for (auto [maybeConstant, result] : llvm::zip(maybeConstants, values)) {
1400 // Don't materialize a constant if there are no uses: this would induce
1401 // infinite loops in the driver.
1402 if (result.use_empty() || maybeConstant == getAsOpFoldResult(result))
1403 continue;
1404 assert(isa<Attribute>(maybeConstant) &&
1405 "The constified value should be either unchanged (i.e., == result) "
1406 "or a constant");
1407 Value constantVal = arith::ConstantIndexOp::create(
1408 rewriter, loc,
1409 llvm::cast<IntegerAttr>(cast<Attribute>(maybeConstant)).getInt());
1410 for (Operation *op : llvm::make_early_inc_range(result.getUsers())) {
1411 // modifyOpInPlace: lambda cannot capture structured bindings in C++17
1412 // yet.
1413 op->replaceUsesOfWith(result, constantVal);
1414 atLeastOneReplacement = true;
1415 }
1416 }
1417 return atLeastOneReplacement;
1418}
1419
1420LogicalResult
1421ExtractStridedMetadataOp::fold(FoldAdaptor adaptor,
1422 SmallVectorImpl<OpFoldResult> &results) {
1423 OpBuilder builder(*this);
1424
1425 bool atLeastOneReplacement = replaceConstantUsesOf(
1426 builder, getLoc(), ArrayRef<TypedValue<IndexType>>(getOffset()),
1427 getConstifiedMixedOffset());
1428 atLeastOneReplacement |= replaceConstantUsesOf(builder, getLoc(), getSizes(),
1429 getConstifiedMixedSizes());
1430 atLeastOneReplacement |= replaceConstantUsesOf(
1431 builder, getLoc(), getStrides(), getConstifiedMixedStrides());
1432
1433 // extract_strided_metadata(cast(x)) -> extract_strided_metadata(x).
1434 if (auto prev = getSource().getDefiningOp<CastOp>())
1435 if (isa<MemRefType>(prev.getSource().getType())) {
1436 getSourceMutable().assign(prev.getSource());
1437 atLeastOneReplacement = true;
1438 }
1439
1440 return success(atLeastOneReplacement);
1441}
1442
1443SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() {
1444 SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes());
1445 constifyIndexValues(values, getSource().getType().getShape());
1446 return values;
1447}
1448
1449SmallVector<OpFoldResult>
1450ExtractStridedMetadataOp::getConstifiedMixedStrides() {
1451 SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides());
1452 SmallVector<int64_t> staticValues;
1453 int64_t unused;
1454 LogicalResult status =
1455 getSource().getType().getStridesAndOffset(staticValues, unused);
1456 (void)status;
1457 assert(succeeded(status) && "could not get strides from type");
1458 constifyIndexValues(values, staticValues);
1459 return values;
1460}
1461
1462OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() {
1463 OpFoldResult offsetOfr = getAsOpFoldResult(getOffset());
1464 SmallVector<OpFoldResult> values(1, offsetOfr);
1465 SmallVector<int64_t> staticValues, unused;
1466 int64_t offset;
1467 LogicalResult status =
1468 getSource().getType().getStridesAndOffset(unused, offset);
1469 (void)status;
1470 assert(succeeded(status) && "could not get offset from type");
1471 staticValues.push_back(offset);
1472 constifyIndexValues(values, staticValues);
1473 return values[0];
1474}
1475
1476//===----------------------------------------------------------------------===//
1477// GenericAtomicRMWOp
1478//===----------------------------------------------------------------------===//
1479
1480void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
1481 Value memref, ValueRange ivs) {
1482 OpBuilder::InsertionGuard g(builder);
1483 result.addOperands(memref);
1484 result.addOperands(ivs);
1485
1486 if (auto memrefType = llvm::dyn_cast<MemRefType>(memref.getType())) {
1487 Type elementType = memrefType.getElementType();
1488 result.addTypes(elementType);
1489
1490 Region *bodyRegion = result.addRegion();
1491 builder.createBlock(bodyRegion);
1492 bodyRegion->addArgument(elementType, memref.getLoc());
1493 }
1494}
1495
1496LogicalResult GenericAtomicRMWOp::verify() {
1497 auto &body = getRegion();
1498 if (body.getNumArguments() != 1)
1499 return emitOpError("expected single number of entry block arguments");
1500
1501 if (getResult().getType() != body.getArgument(0).getType())
1502 return emitOpError("expected block argument of the same type result type");
1503
1504 bool hasSideEffects =
1505 body.walk([&](Operation *nestedOp) {
1506 if (isMemoryEffectFree(nestedOp))
1507 return WalkResult::advance();
1508 nestedOp->emitError(
1509 "body of 'memref.generic_atomic_rmw' should contain "
1510 "only operations with no side effects");
1511 return WalkResult::interrupt();
1512 })
1513 .wasInterrupted();
1514 return hasSideEffects ? failure() : success();
1515}
1516
1517ParseResult GenericAtomicRMWOp::parse(OpAsmParser &parser,
1518 OperationState &result) {
1519 OpAsmParser::UnresolvedOperand memref;
1520 Type memrefType;
1521 SmallVector<OpAsmParser::UnresolvedOperand, 4> ivs;
1522
1523 Type indexType = parser.getBuilder().getIndexType();
1524 if (parser.parseOperand(memref) ||
1525 parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) ||
1526 parser.parseColonType(memrefType) ||
1527 parser.resolveOperand(memref, memrefType, result.operands) ||
1528 parser.resolveOperands(ivs, indexType, result.operands))
1529 return failure();
1530
1531 Region *body = result.addRegion();
1532 if (parser.parseRegion(*body, {}) ||
1533 parser.parseOptionalAttrDict(result.attributes))
1534 return failure();
1535 result.types.push_back(llvm::cast<MemRefType>(memrefType).getElementType());
1536 return success();
1537}
1538
1539void GenericAtomicRMWOp::print(OpAsmPrinter &p) {
1540 p << ' ' << getMemref() << "[" << getIndices()
1541 << "] : " << getMemref().getType() << ' ';
1542 p.printRegion(getRegion());
1543 p.printOptionalAttrDict((*this)->getAttrs());
1544}
1545
1546//===----------------------------------------------------------------------===//
1547// AtomicYieldOp
1548//===----------------------------------------------------------------------===//
1549
1550LogicalResult AtomicYieldOp::verify() {
1551 Type parentType = (*this)->getParentOp()->getResultTypes().front();
1552 Type resultType = getResult().getType();
1553 if (parentType != resultType)
1554 return emitOpError() << "types mismatch between yield op: " << resultType
1555 << " and its parent: " << parentType;
1556 return success();
1557}
1558
1559//===----------------------------------------------------------------------===//
1560// GlobalOp
1561//===----------------------------------------------------------------------===//
1562
1563static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1564 TypeAttr type,
1565 Attribute initialValue) {
1566 p << type;
1567 if (!op.isExternal()) {
1568 p << " = ";
1569 if (op.isUninitialized())
1570 p << "uninitialized";
1571 else
1572 p.printAttributeWithoutType(initialValue);
1573 }
1574}
1575
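// For reference, the type-and-initial-value portion handled by these helpers
// appears in the IR roughly as follows (illustrative examples):
//
//   memref.global "private" constant @c : memref<2xf32> = dense<[0.0, 2.0]>
//   memref.global @z : memref<3xf16> = uninitialized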
1576static ParseResult
1577parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1578 Attribute &initialValue) {
1579 Type type;
1580 if (parser.parseType(type))
1581 return failure();
1582
1583 auto memrefType = llvm::dyn_cast<MemRefType>(type);
1584 if (!memrefType || !memrefType.hasStaticShape())
1585 return parser.emitError(parser.getNameLoc())
1586 << "type should be static shaped memref, but got " << type;
1587 typeAttr = TypeAttr::get(type);
1588
1589 if (parser.parseOptionalEqual())
1590 return success();
1591
1592 if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1593 initialValue = UnitAttr::get(parser.getContext());
1594 return success();
1595 }
1596
1597 Type tensorType = getTensorTypeFromMemRefType(memrefType);
1598 if (parser.parseAttribute(initialValue, tensorType))
1599 return failure();
1600 if (!llvm::isa<ElementsAttr>(initialValue))
1601 return parser.emitError(parser.getNameLoc())
1602 << "initial value should be a unit or elements attribute";
1603 return success();
1604}
1605
1606LogicalResult GlobalOp::verify() {
1607 auto memrefType = llvm::dyn_cast<MemRefType>(getType());
1608 if (!memrefType || !memrefType.hasStaticShape())
1609 return emitOpError("type should be static shaped memref, but got ")
1610 << getType();
1611
1612 // Verify that the initial value, if present, is either a unit attribute or
1613 // an elements attribute.
1614 if (getInitialValue().has_value()) {
1615 Attribute initValue = getInitialValue().value();
1616 if (!llvm::isa<UnitAttr>(initValue) && !llvm::isa<ElementsAttr>(initValue))
1617 return emitOpError("initial value should be a unit or elements "
1618 "attribute, but got ")
1619 << initValue;
1620
1621 // Check that the type of the initial value is compatible with the type of
1622 // the global variable.
1623 if (auto elementsAttr = llvm::dyn_cast<ElementsAttr>(initValue)) {
1624 // Check the element types match.
1625 auto initElementType =
1626 cast<TensorType>(elementsAttr.getType()).getElementType();
1627 auto memrefElementType = memrefType.getElementType();
1628
1629 if (initElementType != memrefElementType)
1630 return emitOpError("initial value element expected to be of type ")
1631 << memrefElementType << ", but was of type " << initElementType;
1632
1633 // Check that the shapes match. Given that memref globals can only produce
1634 // statically shaped memrefs and an elements literal type must have a
1635 // static shape, we can assume both types are shaped.
1636 auto initShape = elementsAttr.getShapedType().getShape();
1637 auto memrefShape = memrefType.getShape();
1638 if (initShape != memrefShape)
1639 return emitOpError("initial value shape expected to be ")
1640 << memrefShape << " but was " << initShape;
1641 }
1642 }
1643
1644 // TODO: verify visibility for declarations.
1645 return success();
1646}
1647
1648ElementsAttr GlobalOp::getConstantInitValue() {
1649 auto initVal = getInitialValue();
1650 if (getConstant() && initVal.has_value())
1651 return llvm::cast<ElementsAttr>(initVal.value());
1652 return {};
1653}
1654
1655//===----------------------------------------------------------------------===//
1656// GetGlobalOp
1657//===----------------------------------------------------------------------===//
1658
1659LogicalResult
1660GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1661 // Verify that the result type is the same as the type of the referenced
1662 // memref.global op.
1663 auto global =
1664 symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, getNameAttr());
1665 if (!global)
1666 return emitOpError("'")
1667 << getName() << "' does not reference a valid global memref";
1668
1669 Type resultType = getResult().getType();
1670 if (global.getType() != resultType)
1671 return emitOpError("result type ")
1672 << resultType << " does not match type " << global.getType()
1673 << " of the global memref @" << getName();
1674 return success();
1675}
1676
1677//===----------------------------------------------------------------------===//
1678// LoadOp
1679//===----------------------------------------------------------------------===//
1680
1681LogicalResult LoadOp::verify() {
1682 if (static_cast<int64_t>(getIndices().size()) != getMemRefType().getRank()) {
1683 return emitOpError("incorrect number of indices for load, expected ")
1684 << getMemRefType().getRank() << " but got " << getIndices().size();
1685 }
1686 return success();
1687}
1688
1689OpFoldResult LoadOp::fold(FoldAdaptor adaptor) {
1690 /// load(memrefcast) -> load
1691 if (succeeded(foldMemRefCast(*this)))
1692 return getResult();
1693 return OpFoldResult();
1694}
1695
1696FailureOr<std::optional<SmallVector<Value>>>
1697LoadOp::bubbleDownCasts(OpBuilder &builder) {
1699 getResult());
1700}
1701
1702//===----------------------------------------------------------------------===//
1703// MemorySpaceCastOp
1704//===----------------------------------------------------------------------===//
1705
1706void MemorySpaceCastOp::getAsmResultNames(
1707 function_ref<void(Value, StringRef)> setNameFn) {
1708 setNameFn(getResult(), "memspacecast");
1709}
1710
1711bool MemorySpaceCastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
1712 if (inputs.size() != 1 || outputs.size() != 1)
1713 return false;
1714 Type a = inputs.front(), b = outputs.front();
1715 auto aT = llvm::dyn_cast<MemRefType>(a);
1716 auto bT = llvm::dyn_cast<MemRefType>(b);
1717
1718 auto uaT = llvm::dyn_cast<UnrankedMemRefType>(a);
1719 auto ubT = llvm::dyn_cast<UnrankedMemRefType>(b);
1720
1721 if (aT && bT) {
1722 if (aT.getElementType() != bT.getElementType())
1723 return false;
1724 if (aT.getLayout() != bT.getLayout())
1725 return false;
1726 if (aT.getShape() != bT.getShape())
1727 return false;
1728 return true;
1729 }
1730 if (uaT && ubT) {
1731 return uaT.getElementType() == ubT.getElementType();
1732 }
1733 return false;
1734}
1735
1736OpFoldResult MemorySpaceCastOp::fold(FoldAdaptor adaptor) {
1737 // memory_space_cast(memory_space_cast(v, t1), t2) -> memory_space_cast(v,
1738 // t2)
1739 if (auto parentCast = getSource().getDefiningOp<MemorySpaceCastOp>()) {
1740 getSourceMutable().assign(parentCast.getSource());
1741 return getResult();
1742 }
1743 return Value{};
1744}
1745
1746TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getSourcePtr() {
1747 return getSource();
1748}
1749
1750TypedValue<PtrLikeTypeInterface> MemorySpaceCastOp::getTargetPtr() {
1751 return getDest();
1752}
1753
1754bool MemorySpaceCastOp::isValidMemorySpaceCast(PtrLikeTypeInterface tgt,
1755 PtrLikeTypeInterface src) {
1756 return isa<BaseMemRefType>(tgt) &&
1757 tgt.clonePtrWith(src.getMemorySpace(), std::nullopt) == src;
1758}
1759
1760MemorySpaceCastOpInterface MemorySpaceCastOp::cloneMemorySpaceCastOp(
1761 OpBuilder &b, PtrLikeTypeInterface tgt,
1762 TypedValue<PtrLikeTypeInterface> src) {
1763 assert(isValidMemorySpaceCast(tgt, src.getType()) && "invalid arguments");
1764 return MemorySpaceCastOp::create(b, getLoc(), tgt, src);
1765}
1766
1767/// The only cast we recognize as promotable is to the generic space.
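/// For instance (illustrative), a cast from memref<4xf32, 1> into
/// memref<4xf32> (no memory space, i.e. the generic space) is promotable,
/// whereas a cast into any explicit memory space is not.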
1768bool MemorySpaceCastOp::isSourcePromotable() {
1769 return getDest().getType().getMemorySpace() == nullptr;
1770}
1771
1772//===----------------------------------------------------------------------===//
1773// PrefetchOp
1774//===----------------------------------------------------------------------===//
1775
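// For reference, the custom assembly form handled by print() and parse()
// below looks roughly like this (illustrative example):
//
//   memref.prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xf32>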
1776void PrefetchOp::print(OpAsmPrinter &p) {
1777 p << " " << getMemref() << '[';
1778 p.printOperands(getIndices());
1779 p << ']' << ", " << (getIsWrite() ? "write" : "read");
1780 p << ", locality<" << getLocalityHint();
1781 p << ">, " << (getIsDataCache() ? "data" : "instr");
1782 p.printOptionalAttrDict(
1783 (*this)->getAttrs(),
1784 /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1785 p << " : " << getMemRefType();
1786}
1787
1788ParseResult PrefetchOp::parse(OpAsmParser &parser, OperationState &result) {
1789 OpAsmParser::UnresolvedOperand memrefInfo;
1790 SmallVector<OpAsmParser::UnresolvedOperand, 4> indexInfo;
1791 IntegerAttr localityHint;
1792 MemRefType type;
1793 StringRef readOrWrite, cacheType;
1794
1795 auto indexTy = parser.getBuilder().getIndexType();
1796 auto i32Type = parser.getBuilder().getIntegerType(32);
1797 if (parser.parseOperand(memrefInfo) ||
1798 parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1799 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1800 parser.parseComma() || parser.parseKeyword("locality") ||
1801 parser.parseLess() ||
1802 parser.parseAttribute(localityHint, i32Type, "localityHint",
1803 result.attributes) ||
1804 parser.parseGreater() || parser.parseComma() ||
1805 parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1806 parser.resolveOperand(memrefInfo, type, result.operands) ||
1807 parser.resolveOperands(indexInfo, indexTy, result.operands))
1808 return failure();
1809
1810 if (readOrWrite != "read" && readOrWrite != "write")
1811 return parser.emitError(parser.getNameLoc(),
1812 "rw specifier has to be 'read' or 'write'");
1813 result.addAttribute(PrefetchOp::getIsWriteAttrStrName(),
1814 parser.getBuilder().getBoolAttr(readOrWrite == "write"));
1815
1816 if (cacheType != "data" && cacheType != "instr")
1817 return parser.emitError(parser.getNameLoc(),
1818 "cache type has to be 'data' or 'instr'");
1819
1820 result.addAttribute(PrefetchOp::getIsDataCacheAttrStrName(),
1821 parser.getBuilder().getBoolAttr(cacheType == "data"));
1822
1823 return success();
1824}
1825
1826LogicalResult PrefetchOp::verify() {
1827 if (getNumOperands() != 1 + getMemRefType().getRank())
1828 return emitOpError("too few indices");
1829
1830 return success();
1831}
1832
1833LogicalResult PrefetchOp::fold(FoldAdaptor adaptor,
1834 SmallVectorImpl<OpFoldResult> &results) {
1835 // prefetch(memrefcast) -> prefetch
1836 return foldMemRefCast(*this);
1837}
1838
1839//===----------------------------------------------------------------------===//
1840// RankOp
1841//===----------------------------------------------------------------------===//
1842
1843OpFoldResult RankOp::fold(FoldAdaptor adaptor) {
1844 // Constant fold rank when the rank of the operand is known.
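 // E.g., the rank of memref<4x?xf32> folds to the index constant 2; an
 // unranked memref<*xf32> does not fold.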
1845 auto type = getOperand().getType();
1846 auto shapedType = llvm::dyn_cast<ShapedType>(type);
1847 if (shapedType && shapedType.hasRank())
1848 return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank());
1849 return IntegerAttr();
1850}
1851
1852//===----------------------------------------------------------------------===//
1853// ReinterpretCastOp
1854//===----------------------------------------------------------------------===//
1855
1856void ReinterpretCastOp::getAsmResultNames(
1857 function_ref<void(Value, StringRef)> setNameFn) {
1858 setNameFn(getResult(), "reinterpret_cast");
1859}
1860
1861/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1862/// `staticSizes` and `staticStrides` are automatically filled with
1863/// source-memref-rank sentinel values that encode dynamic entries.
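/// For instance (illustrative), passing sizes = [%d, 4] yields the static
/// size array [ShapedType::kDynamic, 4] and records %d as the single dynamic
/// size operand.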
1864void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1865 MemRefType resultType, Value source,
1866 OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1867 ArrayRef<OpFoldResult> strides,
1868 ArrayRef<NamedAttribute> attrs) {
1869 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1870 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1871 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1872 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1873 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1874 result.addAttributes(attrs);
1875 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1876 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
1877 b.getDenseI64ArrayAttr(staticSizes),
1878 b.getDenseI64ArrayAttr(staticStrides));
1879}
1880
1881void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1882 Value source, OpFoldResult offset,
1883 ArrayRef<OpFoldResult> sizes,
1884 ArrayRef<OpFoldResult> strides,
1885 ArrayRef<NamedAttribute> attrs) {
1886 auto sourceType = cast<BaseMemRefType>(source.getType());
1887 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1888 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1889 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets);
1890 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
1891 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
1892 auto stridedLayout = StridedLayoutAttr::get(
1893 b.getContext(), staticOffsets.front(), staticStrides);
1894 auto resultType = MemRefType::get(staticSizes, sourceType.getElementType(),
1895 stridedLayout, sourceType.getMemorySpace());
1896 build(b, result, resultType, source, offset, sizes, strides, attrs);
1897}
1898
1899void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1900 MemRefType resultType, Value source,
1901 int64_t offset, ArrayRef<int64_t> sizes,
1902 ArrayRef<int64_t> strides,
1903 ArrayRef<NamedAttribute> attrs) {
1904 SmallVector<OpFoldResult> sizeValues =
1905 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1906 return b.getI64IntegerAttr(v);
1907 }));
1908 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1909 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1910 return b.getI64IntegerAttr(v);
1911 }));
1912 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1913 strideValues, attrs);
1914}
1915
1916void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1917 MemRefType resultType, Value source, Value offset,
1918 ValueRange sizes, ValueRange strides,
1919 ArrayRef<NamedAttribute> attrs) {
1920 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1921 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1922 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1923 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1924 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1925}
1926
1927// TODO: ponder whether we want to allow missing trailing sizes/strides that are
1928// completed automatically, like we have for subview and extract_slice.
1929LogicalResult ReinterpretCastOp::verify() {
1930 // The source and result memrefs should be in the same memory space.
1931 auto srcType = llvm::cast<BaseMemRefType>(getSource().getType());
1932 auto resultType = llvm::cast<MemRefType>(getType());
1933 if (srcType.getMemorySpace() != resultType.getMemorySpace())
1934 return emitError("different memory spaces specified for source type ")
1935 << srcType << " and result memref type " << resultType;
1936 if (srcType.getElementType() != resultType.getElementType())
1937 return emitError("different element types specified for source type ")
1938 << srcType << " and result memref type " << resultType;
1939
1940 // Match sizes in result memref type and in static_sizes attribute.
1941 for (auto [idx, resultSize, expectedSize] :
1942 llvm::enumerate(resultType.getShape(), getStaticSizes())) {
1943 if (ShapedType::isStatic(resultSize) && resultSize != expectedSize)
1944 return emitError("expected result type with size = ")
1945 << (ShapedType::isDynamic(expectedSize)
1946 ? std::string("dynamic")
1947 : std::to_string(expectedSize))
1948 << " instead of " << resultSize << " in dim = " << idx;
1949 }
1950
1951 // Match offset and strides in static_offset and static_strides attributes. If
1952 // result memref type has no affine map specified, this will assume an
1953 // identity layout.
1954 int64_t resultOffset;
1955 SmallVector<int64_t, 4> resultStrides;
1956 if (failed(resultType.getStridesAndOffset(resultStrides, resultOffset)))
1957 return emitError("expected result type to have strided layout but found ")
1958 << resultType;
1959
1960 // Match offset in result memref type and in static_offsets attribute.
1961 int64_t expectedOffset = getStaticOffsets().front();
1962 if (ShapedType::isStatic(resultOffset) && resultOffset != expectedOffset)
1963 return emitError("expected result type with offset = ")
1964 << (ShapedType::isDynamic(expectedOffset)
1965 ? std::string("dynamic")
1966 : std::to_string(expectedOffset))
1967 << " instead of " << resultOffset;
1968
1969 // Match strides in result memref type and in static_strides attribute.
1970 for (auto [idx, resultStride, expectedStride] :
1971 llvm::enumerate(resultStrides, getStaticStrides())) {
1972 if (ShapedType::isStatic(resultStride) && resultStride != expectedStride)
1973 return emitError("expected result type with stride = ")
1974 << (ShapedType::isDynamic(expectedStride)
1975 ? std::string("dynamic")
1976 : std::to_string(expectedStride))
1977 << " instead of " << resultStride << " in dim = " << idx;
1978 }
1979
1980 return success();
1981}
1982
1983OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) {
1984 Value src = getSource();
1985 auto getPrevSrc = [&]() -> Value {
1986 // reinterpret_cast(reinterpret_cast(x)) -> reinterpret_cast(x).
1987 if (auto prev = src.getDefiningOp<ReinterpretCastOp>())
1988 return prev.getSource();
1989
1990 // reinterpret_cast(cast(x)) -> reinterpret_cast(x).
1991 if (auto prev = src.getDefiningOp<CastOp>())
1992 return prev.getSource();
1993
1994 // reinterpret_cast(subview(x)) -> reinterpret_cast(x) if subview offsets
1995 // are 0.
1996 if (auto prev = src.getDefiningOp<SubViewOp>())
1997 if (llvm::all_of(prev.getMixedOffsets(), isZeroInteger))
1998 return prev.getSource();
1999
2000 return nullptr;
2001 };
2002
2003 if (auto prevSrc = getPrevSrc()) {
2004 getSourceMutable().assign(prevSrc);
2005 return getResult();
2006 }
2007
2008 // reinterpret_cast(x) w/o offset/shape/stride changes -> x
2009 if (ShapedType::isStaticShape(getType().getShape()) &&
2010 src.getType() == getType() && getStaticOffsets().front() == 0) {
2011 return src;
2012 }
2013
2014 return nullptr;
2015}
2016
2017SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() {
2018 SmallVector<OpFoldResult> values = getMixedSizes();
2019 constifyIndexValues(values, getType().getShape());
2020 return values;
2021}
2022
2023SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() {
2024 SmallVector<OpFoldResult> values = getMixedStrides();
2025 SmallVector<int64_t> staticValues;
2026 int64_t unused;
2027 LogicalResult status = getType().getStridesAndOffset(staticValues, unused);
2028 (void)status;
2029 assert(succeeded(status) && "could not get strides from type");
2030 constifyIndexValues(values, staticValues);
2031 return values;
2032}
2033
2034OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() {
2035 SmallVector<OpFoldResult> values = getMixedOffsets();
2036 assert(values.size() == 1 &&
2037 "reinterpret_cast must have one and only one offset");
2038 SmallVector<int64_t> staticValues, unused;
2039 int64_t offset;
2040 LogicalResult status = getType().getStridesAndOffset(unused, offset);
2041 (void)status;
2042 assert(succeeded(status) && "could not get offset from type");
2043 staticValues.push_back(offset);
2044 constifyIndexValues(values, staticValues);
2045 return values[0];
2046}
2047
2048namespace {
2049/// Replace the sequence:
2050/// ```
2051/// base, offset, sizes, strides = extract_strided_metadata src
2052/// dst = reinterpret_cast base to offset, sizes, strides
2053/// ```
2054/// With
2055///
2056/// ```
2057/// dst = memref.cast src
2058/// ```
2059///
2060/// Note: The cast operation is only inserted when the type of dst and src
2061/// are not the same. E.g., when going from <4xf32> to <?xf32>.
2062///
2063/// This pattern also matches when the offset, sizes, and strides don't come
2064/// directly from the `extract_strided_metadata`'s results but it can be
2065/// statically proven that they would hold the same values.
2066///
2067/// For instance, the following sequence would be replaced:
2068/// ```
2069/// base, offset, sizes, strides =
2070/// extract_strided_metadata memref : memref<3x4xty>
2071/// dst = reinterpret_cast base to 0, [3, 4], strides
2072/// ```
2073/// Because we know (thanks to the type of the input memref) that variable
2074/// `offset` and `sizes` will respectively hold 0 and [3, 4].
2075///
2076/// Similarly, the following sequence would be replaced:
2077/// ```
2078/// c0 = arith.constant 0
2079/// c4 = arith.constant 4
2080/// base, offset, sizes, strides =
2081/// extract_strided_metadata memref : memref<3x4xty>
2082/// dst = reinterpret_cast base to c0, [3, c4], strides
2083/// ```
2084/// Because we know that `offset` and `c0` will hold 0
2085/// and `c4` will hold 4.
2086///
2087/// If the pattern above does not match, the input of the
2088/// extract_strided_metadata is always folded into the input of the
2089/// reinterpret_cast operator. This allows for dead code elimination to get rid
2090/// of the extract_strided_metadata in some cases.
2091struct ReinterpretCastOpExtractStridedMetadataFolder
2092 : public OpRewritePattern<ReinterpretCastOp> {
2093public:
2094 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2095
2096 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2097 PatternRewriter &rewriter) const override {
2098 auto extractStridedMetadata =
2099 op.getSource().getDefiningOp<ExtractStridedMetadataOp>();
2100 if (!extractStridedMetadata)
2101 return failure();
2102
2103 // Check if the reinterpret cast reconstructs a memref with the exact same
2104 // properties as the extract strided metadata.
2105 auto isReinterpretCastNoop = [&]() -> bool {
2106 // First, check that the strides are the same.
2107 if (!llvm::equal(extractStridedMetadata.getConstifiedMixedStrides(),
2108 op.getConstifiedMixedStrides()))
2109 return false;
2110
2111 // Second, check the sizes.
2112 if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(),
2113 op.getConstifiedMixedSizes()))
2114 return false;
2115
2116 // Finally, check the offset.
2117 assert(op.getMixedOffsets().size() == 1 &&
2118 "reinterpret_cast with more than one offset should have been "
2119 "rejected by the verifier");
2120 return extractStridedMetadata.getConstifiedMixedOffset() ==
2121 op.getConstifiedMixedOffset();
2122 };
2123
2124 if (!isReinterpretCastNoop()) {
2125 // If the extract_strided_metadata / reinterpret_cast pair can't be
2126 // completely folded, then we could fold the input of the
2127 // extract_strided_metadata into the input of the reinterpret_cast.
2128 // For some cases (e.g., static dimensions) the extract_strided_metadata
2129 // is then eliminated by dead code elimination.
2130 //
2131 // reinterpret_cast(extract_strided_metadata(x)) -> reinterpret_cast(x).
2132 //
2133 // We can always fold the input of a extract_strided_metadata operator
2134 // to the input of a reinterpret_cast operator, because they point to
2135 // the same memory. Note that the reinterpret_cast does not use the
2136 // layout of its input memref, only its base memory pointer which is
2137 // the same as the base pointer returned by the extract_strided_metadata
2138 // operator and the base pointer of the extract_strided_metadata memref
2139 // input.
2140 rewriter.modifyOpInPlace(op, [&]() {
2141 op.getSourceMutable().assign(extractStridedMetadata.getSource());
2142 });
2143 return success();
2144 }
2145
2146 // At this point, we know that the back and forth between extract strided
2147 // metadata and reinterpret cast is a noop. However, the final type of the
2148 // reinterpret cast may not be exactly the same as the original memref.
2149 // E.g., it could be changing a dimension from static to dynamic. Check that
2150 // here and add a cast if necessary.
2151 Type srcTy = extractStridedMetadata.getSource().getType();
2152 if (srcTy == op.getResult().getType())
2153 rewriter.replaceOp(op, extractStridedMetadata.getSource());
2154 else
2155 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(),
2156 extractStridedMetadata.getSource());
2157
2158 return success();
2159 }
2160};
2161
2162struct ReinterpretCastOpConstantFolder
2163 : public OpRewritePattern<ReinterpretCastOp> {
2164public:
2165 using OpRewritePattern<ReinterpretCastOp>::OpRewritePattern;
2166
2167 LogicalResult matchAndRewrite(ReinterpretCastOp op,
2168 PatternRewriter &rewriter) const override {
2169 unsigned srcStaticCount = llvm::count_if(
2170 llvm::concat<OpFoldResult>(op.getMixedOffsets(), op.getMixedSizes(),
2171 op.getMixedStrides()),
2172 [](OpFoldResult ofr) { return isa<Attribute>(ofr); });
2173
2174 SmallVector<OpFoldResult> offsets = {op.getConstifiedMixedOffset()};
2175 SmallVector<OpFoldResult> sizes = op.getConstifiedMixedSizes();
2176 SmallVector<OpFoldResult> strides = op.getConstifiedMixedStrides();
2177
2178 // TODO: Using counting comparison instead of direct comparison because
2179 // getMixedValues (and therefore ReinterpretCastOp::getMixed...) returns
2180 // IntegerAttrs, while constifyIndexValues (and therefore
2181 // ReinterpretCastOp::getConstifiedMixed...) returns IndexAttrs.
2182 if (srcStaticCount ==
2183 llvm::count_if(llvm::concat<OpFoldResult>(offsets, sizes, strides),
2184 [](OpFoldResult ofr) { return isa<Attribute>(ofr); }))
2185 return failure();
2186
2187 auto newReinterpretCast = ReinterpretCastOp::create(
2188 rewriter, op->getLoc(), op.getSource(), offsets[0], sizes, strides);
2189
2190 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newReinterpretCast);
2191 return success();
2192 }
2193};
2194} // namespace
2195
2196void ReinterpretCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
2197 MLIRContext *context) {
2198 results.add<ReinterpretCastOpExtractStridedMetadataFolder,
2199 ReinterpretCastOpConstantFolder>(context);
2200}
2201
2202FailureOr<std::optional<SmallVector<Value>>>
2203ReinterpretCastOp::bubbleDownCasts(OpBuilder &builder) {
2204 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2205}
2206
2207//===----------------------------------------------------------------------===//
2208// Reassociative reshape ops
2209//===----------------------------------------------------------------------===//
2210
2211void CollapseShapeOp::getAsmResultNames(
2212 function_ref<void(Value, StringRef)> setNameFn) {
2213 setNameFn(getResult(), "collapse_shape");
2214}
2215
2216void ExpandShapeOp::getAsmResultNames(
2217 function_ref<void(Value, StringRef)> setNameFn) {
2218 setNameFn(getResult(), "expand_shape");
2219}
2220
2221LogicalResult ExpandShapeOp::reifyResultShapes(
2222 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedResultShapes) {
2223 reifiedResultShapes = {
2224 getMixedValues(getStaticOutputShape(), getOutputShape(), builder)};
2225 return success();
2226}
2227
2228/// Helper function for verifying the shape of ExpandShapeOp and CollapseShapeOp
2229/// result and operand. Layout maps are verified separately.
2230///
2231/// If `allowMultipleDynamicDimsPerGroup`, multiple dynamic dimensions are
2232/// allowed in a reassociation group.
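///
/// For instance (illustrative), collapsing memref<2x3x4xf32> into
/// memref<6x4xf32> with reassociation [[0, 1], [2]] passes this check:
/// group [0, 1] multiplies to 6 and group [2] to 4, matching the collapsed
/// shape.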
2233static LogicalResult
2234verifyCollapsedShape(Operation *op, ArrayRef<int64_t> collapsedShape,
2235 ArrayRef<int64_t> expandedShape,
2236 ArrayRef<ReassociationIndices> reassociation,
2237 bool allowMultipleDynamicDimsPerGroup) {
2238 // There must be one reassociation group per collapsed dimension.
2239 if (collapsedShape.size() != reassociation.size())
2240 return op->emitOpError("invalid number of reassociation groups: found ")
2241 << reassociation.size() << ", expected " << collapsedShape.size();
2242
2243 // The next expected expanded dimension index (while iterating over
2244 // reassociation indices).
2245 int64_t nextDim = 0;
2246 for (const auto &it : llvm::enumerate(reassociation)) {
2247 ReassociationIndices group = it.value();
2248 int64_t collapsedDim = it.index();
2249
2250 bool foundDynamic = false;
2251 for (int64_t expandedDim : group) {
2252 if (expandedDim != nextDim++)
2253 return op->emitOpError("reassociation indices must be contiguous");
2254
2255 if (expandedDim >= static_cast<int64_t>(expandedShape.size()))
2256 return op->emitOpError("reassociation index ")
2257 << expandedDim << " is out of bounds";
2258
2259 // Check if there are multiple dynamic dims in a reassociation group.
2260 if (ShapedType::isDynamic(expandedShape[expandedDim])) {
2261 if (foundDynamic && !allowMultipleDynamicDimsPerGroup)
2262 return op->emitOpError(
2263 "at most one dimension in a reassociation group may be dynamic");
2264 foundDynamic = true;
2265 }
2266 }
2267
2268 // ExpandShapeOp/CollapseShapeOp may not be used to cast dynamicity.
2269 if (ShapedType::isDynamic(collapsedShape[collapsedDim]) != foundDynamic)
2270 return op->emitOpError("collapsed dim (")
2271 << collapsedDim
2272 << ") must be dynamic if and only if reassociation group is "
2273 "dynamic";
2274
2275 // If all dims in the reassociation group are static, the size of the
2276 // collapsed dim can be verified.
2277 if (!foundDynamic) {
2278 int64_t groupSize = 1;
2279 for (int64_t expandedDim : group)
2280 groupSize *= expandedShape[expandedDim];
2281 if (groupSize != collapsedShape[collapsedDim])
2282 return op->emitOpError("collapsed dim size (")
2283 << collapsedShape[collapsedDim]
2284 << ") must equal reassociation group size (" << groupSize << ")";
2285 }
2286 }
2287
2288 if (collapsedShape.empty()) {
2289 // Rank 0: All expanded dimensions must be 1.
2290 for (int64_t d : expandedShape)
2291 if (d != 1)
2292 return op->emitOpError(
2293 "rank 0 memrefs can only be extended/collapsed with/from ones");
2294 } else if (nextDim != static_cast<int64_t>(expandedShape.size())) {
2295 // Rank >= 1: Number of dimensions among all reassociation groups must match
2296 // the result memref rank.
2297 return op->emitOpError("expanded rank (")
2298 << expandedShape.size()
2299 << ") inconsistent with number of reassociation indices (" << nextDim
2300 << ")";
2301 }
2302
2303 return success();
2304}
2305
2306SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
2307 return getSymbolLessAffineMaps(getReassociationExprs());
2308}
2309
2310SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
2311 return convertReassociationIndicesToExprs(getContext(),
2312 getReassociationIndices());
2313}
2314
2315SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
2316 return getSymbolLessAffineMaps(getReassociationExprs());
2317}
2318
2319SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
2320 return convertReassociationIndicesToExprs(getContext(),
2321 getReassociationIndices());
2322}
2323
2324/// Compute the layout map after expanding a given source MemRef type with the
2325/// specified reassociation indices.
2326static FailureOr<StridedLayoutAttr>
2327computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
2328 ArrayRef<ReassociationIndices> reassociation) {
2329 int64_t srcOffset;
2330 SmallVector<int64_t> srcStrides;
2331 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2332 return failure();
2333 assert(srcStrides.size() == reassociation.size() && "invalid reassociation");
2334
2335 // 1-1 mapping between srcStrides and reassociation packs.
2336 // Each srcStride starts with the given value and gets expanded according to
2337 // the proper entries in resultShape.
2338 // Example:
2339 // srcStrides = [10000, 1 , 100 ],
2340 // reassociations = [ [0], [1], [2, 3, 4]],
2341 // resultSizes = [2, 5, 4, 3, 2] = [ [2], [5], [4, 3, 2]]
2342 // -> For the purpose of stride calculation, the useful sizes are:
2343 // [x, x, x, 3, 2] = [ [x], [x], [x, 3, 2]].
2344 // resultStrides = [10000, 1, 600, 200, 100]
2345 // Note that a stride does not get expanded along the first entry of each
2346 // shape pack.
2347 SmallVector<int64_t> reverseResultStrides;
2348 reverseResultStrides.reserve(resultShape.size());
2349 unsigned shapeIndex = resultShape.size() - 1;
2350 for (auto it : llvm::reverse(llvm::zip(reassociation, srcStrides))) {
2351 ReassociationIndices reassoc = std::get<0>(it);
2352 int64_t currentStrideToExpand = std::get<1>(it);
2353 for (unsigned idx = 0, e = reassoc.size(); idx < e; ++idx) {
2354 reverseResultStrides.push_back(currentStrideToExpand);
2355 currentStrideToExpand =
2356 (SaturatedInteger::wrap(currentStrideToExpand) *
2357 SaturatedInteger::wrap(resultShape[shapeIndex--]))
2358 .asInteger();
2359 }
2360 }
2361 auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
2362 resultStrides.resize(resultShape.size(), 1);
2363 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2364}
2365
2366FailureOr<MemRefType> ExpandShapeOp::computeExpandedType(
2367 MemRefType srcType, ArrayRef<int64_t> resultShape,
2368 ArrayRef<ReassociationIndices> reassociation) {
2369 if (srcType.getLayout().isIdentity()) {
2370 // If the source is contiguous (i.e., no layout map specified), so is the
2371 // result.
2372 MemRefLayoutAttrInterface layout;
2373 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2374 srcType.getMemorySpace());
2375 }
2376
2377 // Source may not be contiguous. Compute the layout map.
2378 FailureOr<StridedLayoutAttr> computedLayout =
2379 computeExpandedLayoutMap(srcType, resultShape, reassociation);
2380 if (failed(computedLayout))
2381 return failure();
2382 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2383 srcType.getMemorySpace());
2384}
2385
2386FailureOr<SmallVector<OpFoldResult>>
2387ExpandShapeOp::inferOutputShape(OpBuilder &b, Location loc,
2388 MemRefType expandedType,
2389 ArrayRef<ReassociationIndices> reassociation,
2390 ArrayRef<OpFoldResult> inputShape) {
2391 std::optional<SmallVector<OpFoldResult>> outputShape =
2392 inferExpandShapeOutputShape(b, loc, expandedType, reassociation,
2393 inputShape);
2394 if (!outputShape)
2395 return failure();
2396 return *outputShape;
2397}
2398
2399void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2400 Type resultType, Value src,
2401 ArrayRef<ReassociationIndices> reassociation,
2402 ArrayRef<OpFoldResult> outputShape) {
2403 auto [staticOutputShape, dynamicOutputShape] =
2404 decomposeMixedValues(SmallVector<OpFoldResult>(outputShape));
2405 build(builder, result, llvm::cast<MemRefType>(resultType), src,
2406 getReassociationIndicesAttribute(builder, reassociation),
2407 dynamicOutputShape, staticOutputShape);
2408}
2409
2410void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2411 Type resultType, Value src,
2412 ArrayRef<ReassociationIndices> reassociation) {
2413 SmallVector<OpFoldResult> inputShape =
2414 getMixedSizes(builder, result.location, src);
2415 MemRefType memrefResultTy = llvm::cast<MemRefType>(resultType);
2416 FailureOr<SmallVector<OpFoldResult>> outputShape = inferOutputShape(
2417 builder, result.location, memrefResultTy, reassociation, inputShape);
2418 // Failure of this assertion usually indicates presence of multiple
2419 // dynamic dimensions in the same reassociation group.
2420 assert(succeeded(outputShape) && "unable to infer output shape");
2421 build(builder, result, memrefResultTy, src, reassociation, *outputShape);
2422}
2423
2424void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2425 ArrayRef<int64_t> resultShape, Value src,
2426 ArrayRef<ReassociationIndices> reassociation) {
2427 // Only ranked memref source values are supported.
2428 auto srcType = llvm::cast<MemRefType>(src.getType());
2429 FailureOr<MemRefType> resultType =
2430 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2431 // Failure of this assertion usually indicates a problem with the source
2432 // type, e.g., could not get strides/offset.
2433 assert(succeeded(resultType) && "could not compute layout");
2434 build(builder, result, *resultType, src, reassociation);
2435}
2436
2437void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
2438 ArrayRef<int64_t> resultShape, Value src,
2439 ArrayRef<ReassociationIndices> reassociation,
2440 ArrayRef<OpFoldResult> outputShape) {
2441 // Only ranked memref source values are supported.
2442 auto srcType = llvm::cast<MemRefType>(src.getType());
2443 FailureOr<MemRefType> resultType =
2444 ExpandShapeOp::computeExpandedType(srcType, resultShape, reassociation);
2445 // Failure of this assertion usually indicates a problem with the source
2446 // type, e.g., could not get strides/offset.
2447 assert(succeeded(resultType) && "could not compute layout");
2448 build(builder, result, *resultType, src, reassociation, outputShape);
2449}
2450
2451LogicalResult ExpandShapeOp::verify() {
2452 MemRefType srcType = getSrcType();
2453 MemRefType resultType = getResultType();
2454
2455 if (srcType.getRank() > resultType.getRank()) {
2456 auto r0 = srcType.getRank();
2457 auto r1 = resultType.getRank();
2458 return emitOpError("has source rank ")
2459 << r0 << " and result rank " << r1 << ". This is not an expansion ("
2460 << r0 << " > " << r1 << ").";
2461 }
2462
2463 // Verify result shape.
2464 if (failed(verifyCollapsedShape(getOperation(), srcType.getShape(),
2465 resultType.getShape(),
2466 getReassociationIndices(),
2467 /*allowMultipleDynamicDimsPerGroup=*/true)))
2468 return failure();
2469
2470 // Compute expected result type (including layout map).
2471 FailureOr<MemRefType> expectedResultType = ExpandShapeOp::computeExpandedType(
2472 srcType, resultType.getShape(), getReassociationIndices());
2473 if (failed(expectedResultType))
2474 return emitOpError("invalid source layout map");
2475
2476 // Check actual result type.
2477 if (*expectedResultType != resultType)
2478 return emitOpError("expected expanded type to be ")
2479 << *expectedResultType << " but found " << resultType;
2480
2481 if ((int64_t)getStaticOutputShape().size() != resultType.getRank())
2482 return emitOpError("expected number of static shape bounds to be equal to "
2483 "the output rank (")
2484 << resultType.getRank() << ") but found "
2485 << getStaticOutputShape().size() << " inputs instead";
2486
2487 if ((int64_t)getOutputShape().size() !=
2488 llvm::count(getStaticOutputShape(), ShapedType::kDynamic))
2489 return emitOpError("mismatch in dynamic dims in output_shape and "
2490 "static_output_shape: static_output_shape has ")
2491 << llvm::count(getStaticOutputShape(), ShapedType::kDynamic)
2492 << " dynamic dims while output_shape has " << getOutputShape().size()
2493 << " values";
2494
2495 // Verify if provided output shapes are in agreement with output type.
2496 DenseI64ArrayAttr staticOutputShapes = getStaticOutputShapeAttr();
2497 ArrayRef<int64_t> resShape = getResult().getType().getShape();
2498 for (auto [pos, shape] : llvm::enumerate(resShape)) {
2499 if (ShapedType::isStatic(shape) && shape != staticOutputShapes[pos]) {
2500 return emitOpError("invalid output shape provided at pos ") << pos;
2501 }
2502 }
2503
2504 return success();
2505}
2506
2507void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2508 MLIRContext *context) {
2509 results.add<
2510 ComposeReassociativeReshapeOps<ExpandShapeOp, ReshapeOpKind::kExpand>,
2511 ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>>(context);
2512}
2513
2514FailureOr<std::optional<SmallVector<Value>>>
2515ExpandShapeOp::bubbleDownCasts(OpBuilder &builder) {
2516 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2517}
2518
2519/// Compute the layout map after collapsing a given source MemRef type with the
2520/// specified reassociation indices.
2521///
2522/// Note: All collapsed dims in a reassociation group must be contiguous. It is
2523/// not possible to check this by inspecting a MemRefType in the general case.
2524/// If non-contiguity cannot be checked statically, the collapse is assumed to
2525/// be valid (and thus accepted by this function) unless `strict = true`.
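///
/// For instance (illustrative): collapsing memref<3x4xf32> (strides [4, 1])
/// with reassociation [[0, 1]] yields the collapsed strides [1], since
/// 1 * 4 matches the outer stride 4. Collapsing a transposed view with
/// strides [1, 3] fails, because 3 * 4 != 1, i.e. the group is not
/// contiguous.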
2526static FailureOr<StridedLayoutAttr>
2527computeCollapsedLayoutMap(MemRefType srcType,
2528 ArrayRef<ReassociationIndices> reassociation,
2529 bool strict = false) {
2530 int64_t srcOffset;
2531 SmallVector<int64_t> srcStrides;
2532 auto srcShape = srcType.getShape();
2533 if (failed(srcType.getStridesAndOffset(srcStrides, srcOffset)))
2534 return failure();
2535
2536 // The result stride of a reassociation group is the stride of the last entry
2537 // of the reassociation. (TODO: Should be the minimum stride in the
2538 // reassociation because strides are not necessarily sorted. E.g., when using
2539 // memref.transpose.) Dimensions of size 1 should be skipped, because their
2540 // strides are meaningless and could have any arbitrary value.
2541 SmallVector<int64_t> resultStrides;
2542 resultStrides.reserve(reassociation.size());
2543 for (const ReassociationIndices &reassoc : reassociation) {
2544 ArrayRef<int64_t> ref = llvm::ArrayRef(reassoc);
2545 while (srcShape[ref.back()] == 1 && ref.size() > 1)
2546 ref = ref.drop_back();
2547 if (ShapedType::isStatic(srcShape[ref.back()]) || ref.size() == 1) {
2548 resultStrides.push_back(srcStrides[ref.back()]);
2549 } else {
2550 // Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
2551 // the corresponding stride may have to be skipped. (See above comment.)
2552 // Therefore, the result stride cannot be statically determined and must
2553 // be dynamic.
2554 resultStrides.push_back(ShapedType::kDynamic);
2555 }
2556 }
2557
2558 // Validate that each reassociation group is contiguous.
2559 unsigned resultStrideIndex = resultStrides.size() - 1;
2560 for (const ReassociationIndices &reassoc : llvm::reverse(reassociation)) {
2561 auto trailingReassocs = ArrayRef<int64_t>(reassoc).drop_front();
2562 auto stride = SaturatedInteger::wrap(resultStrides[resultStrideIndex--]);
2563 for (int64_t idx : llvm::reverse(trailingReassocs)) {
2564 stride = stride * SaturatedInteger::wrap(srcShape[idx]);
2565
2566 // Both source and result stride must have the same static value. In that
2567 // case, we can be sure that the dimensions are collapsible (because they
2568 // are contiguous).
2569 // If `strict = false` (default during op verification), we accept cases
2570 // where one or both strides are dynamic. This is best effort: We reject
2571 // ops where obviously non-contiguous dims are collapsed, but accept ops
2572 // where we cannot be sure statically. Such ops may fail at runtime. See
2573 // the op documentation for details.
2574 auto srcStride = SaturatedInteger::wrap(srcStrides[idx - 1]);
2575 if (strict && (stride.saturated || srcStride.saturated))
2576 return failure();
2577
2578 // Dimensions of size 1 should be skipped, because their strides are
2579 // meaningless and could have any arbitrary value.
2580 if (srcShape[idx - 1] == 1)
2581 continue;
2582
2583 if (!stride.saturated && !srcStride.saturated && stride != srcStride)
2584 return failure();
2585 }
2586 }
2587 return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
2588}
2589
2590bool CollapseShapeOp::isGuaranteedCollapsible(
2591 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2592 // MemRefs with identity layout are always collapsible.
2593 if (srcType.getLayout().isIdentity())
2594 return true;
2595
2596 return succeeded(computeCollapsedLayoutMap(srcType, reassociation,
2597 /*strict=*/true));
2598}
2599
2600MemRefType CollapseShapeOp::computeCollapsedType(
2601 MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
2602 SmallVector<int64_t> resultShape;
2603 resultShape.reserve(reassociation.size());
2604 for (const ReassociationIndices &group : reassociation) {
2605 auto groupSize = SaturatedInteger::wrap(1);
2606 for (int64_t srcDim : group)
2607 groupSize =
2608 groupSize * SaturatedInteger::wrap(srcType.getDimSize(srcDim));
2609 resultShape.push_back(groupSize.asInteger());
2610 }
2611
2612 if (srcType.getLayout().isIdentity()) {
2613 // If the source is contiguous (i.e., no layout map specified), so is the
2614 // result.
2615 MemRefLayoutAttrInterface layout;
2616 return MemRefType::get(resultShape, srcType.getElementType(), layout,
2617 srcType.getMemorySpace());
2618 }
2619
2620 // Source may not be fully contiguous. Compute the layout map.
2621 // Note: Dimensions that are collapsed into a single dim are assumed to be
2622 // contiguous.
2623 FailureOr<StridedLayoutAttr> computedLayout =
2624 computeCollapsedLayoutMap(srcType, reassociation);
2625 assert(succeeded(computedLayout) &&
2626 "invalid source layout map or collapsing non-contiguous dims");
2627 return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
2628 srcType.getMemorySpace());
2629}
2630
2631void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
2632 ArrayRef<ReassociationIndices> reassociation,
2633 ArrayRef<NamedAttribute> attrs) {
2634 auto srcType = llvm::cast<MemRefType>(src.getType());
2635 MemRefType resultType =
2636 CollapseShapeOp::computeCollapsedType(srcType, reassociation);
2638 getReassociationIndicesAttribute(b, reassociation));
2639 build(b, result, resultType, src, attrs);
2640}
2641
2642LogicalResult CollapseShapeOp::verify() {
2643 MemRefType srcType = getSrcType();
2644 MemRefType resultType = getResultType();
2645
2646 if (srcType.getRank() < resultType.getRank()) {
2647 auto r0 = srcType.getRank();
2648 auto r1 = resultType.getRank();
2649 return emitOpError("has source rank ")
2650 << r0 << " and result rank " << r1 << ". This is not a collapse ("
2651 << r0 << " < " << r1 << ").";
2652 }
2653
2654 // Verify result shape.
2655 if (failed(verifyCollapsedShape(getOperation(), resultType.getShape(),
2656 srcType.getShape(), getReassociationIndices(),
2657 /*allowMultipleDynamicDimsPerGroup=*/true)))
2658 return failure();
2659
2660 // Compute expected result type (including layout map).
2661 MemRefType expectedResultType;
2662 if (srcType.getLayout().isIdentity()) {
2663 // If the source is contiguous (i.e., no layout map specified), so is the
2664 // result.
2665 MemRefLayoutAttrInterface layout;
2666 expectedResultType =
2667 MemRefType::get(resultType.getShape(), srcType.getElementType(), layout,
2668 srcType.getMemorySpace());
2669 } else {
2670 // Source may not be fully contiguous. Compute the layout map.
2671 // Note: Dimensions that are collapsed into a single dim are assumed to be
2672 // contiguous.
2673 FailureOr<StridedLayoutAttr> computedLayout =
2674 computeCollapsedLayoutMap(srcType, getReassociationIndices());
2675 if (failed(computedLayout))
2676 return emitOpError(
2677 "invalid source layout map or collapsing non-contiguous dims");
2678 expectedResultType =
2679 MemRefType::get(resultType.getShape(), srcType.getElementType(),
2680 *computedLayout, srcType.getMemorySpace());
2681 }
2682
2683 if (expectedResultType != resultType)
2684 return emitOpError("expected collapsed type to be ")
2685 << expectedResultType << " but found " << resultType;
2686
2687 return success();
2688}
2689
2690class CollapseShapeOpMemRefCastFolder final
2691 : public OpRewritePattern<CollapseShapeOp> {
2692public:
2693 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
2694
2695 LogicalResult matchAndRewrite(CollapseShapeOp op,
2696 PatternRewriter &rewriter) const override {
2697 auto cast = op.getOperand().getDefiningOp<CastOp>();
2698 if (!cast)
2699 return failure();
2700
2701 if (!CastOp::canFoldIntoConsumerOp(cast))
2702 return failure();
2703
2704 Type newResultType = CollapseShapeOp::computeCollapsedType(
2705 llvm::cast<MemRefType>(cast.getOperand().getType()),
2706 op.getReassociationIndices());
2707
2708 if (newResultType == op.getResultType()) {
2709 rewriter.modifyOpInPlace(
2710 op, [&]() { op.getSrcMutable().assign(cast.getSource()); });
2711 } else {
2712 Value newOp =
2713 CollapseShapeOp::create(rewriter, op->getLoc(), cast.getSource(),
2714 op.getReassociationIndices());
2715 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
2716 }
2717 return success();
2718 }
2719};
2720
2721void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2722 MLIRContext *context) {
2723 results.add<
2724 ComposeReassociativeReshapeOps<CollapseShapeOp, ReshapeOpKind::kCollapse>,
2725 ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp, CastOp,
2726 memref::DimOp, MemRefType>,
2727 CollapseShapeOpMemRefCastFolder>(context);
2728}
2729
2730OpFoldResult ExpandShapeOp::fold(FoldAdaptor adaptor) {
2731 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this,
2732 adaptor.getOperands());
2733}
2734
2735OpFoldResult CollapseShapeOp::fold(FoldAdaptor adaptor) {
2736 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this,
2737 adaptor.getOperands());
2738}
2739
2740FailureOr<std::optional<SmallVector<Value>>>
2741CollapseShapeOp::bubbleDownCasts(OpBuilder &builder) {
2742 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSrcMutable());
2743}
2744
2745//===----------------------------------------------------------------------===//
2746// ReshapeOp
2747//===----------------------------------------------------------------------===//
2748
2749void ReshapeOp::getAsmResultNames(
2750 function_ref<void(Value, StringRef)> setNameFn) {
2751 setNameFn(getResult(), "reshape");
2752}
2753
2754LogicalResult ReshapeOp::verify() {
2755 Type operandType = getSource().getType();
2756 Type resultType = getResult().getType();
2757
2758 Type operandElementType =
2759 llvm::cast<ShapedType>(operandType).getElementType();
2760 Type resultElementType = llvm::cast<ShapedType>(resultType).getElementType();
2761 if (operandElementType != resultElementType)
2762 return emitOpError("element types of source and destination memref "
2763 "types should be the same");
2764
2765 if (auto operandMemRefType = llvm::dyn_cast<MemRefType>(operandType))
2766 if (!operandMemRefType.getLayout().isIdentity())
2767 return emitOpError("source memref type should have identity affine map");
2768
2769 int64_t shapeSize =
2770 llvm::cast<MemRefType>(getShape().getType()).getDimSize(0);
2771 auto resultMemRefType = llvm::dyn_cast<MemRefType>(resultType);
2772 if (resultMemRefType) {
2773 if (!resultMemRefType.getLayout().isIdentity())
2774 return emitOpError("result memref type should have identity affine map");
2775 if (shapeSize == ShapedType::kDynamic)
2776 return emitOpError("cannot use shape operand with dynamic length to "
2777 "reshape to statically-ranked memref type");
2778 if (shapeSize != resultMemRefType.getRank())
2779 return emitOpError(
2780 "length of shape operand differs from the result's memref rank");
2781 }
2782 return success();
2783}
2784
2785FailureOr<std::optional<SmallVector<Value>>>
2786ReshapeOp::bubbleDownCasts(OpBuilder &builder) {
2787 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
2788}
2789
2790//===----------------------------------------------------------------------===//
2791// StoreOp
2792//===----------------------------------------------------------------------===//
2793
2794LogicalResult StoreOp::verify() {
2795 if (getNumOperands() != 2 + getMemRefType().getRank())
2796 return emitOpError("store index operand count not equal to memref rank");
2797
2798 return success();
2799}
2800
2801LogicalResult StoreOp::fold(FoldAdaptor adaptor,
2802 SmallVectorImpl<OpFoldResult> &results) {
2803 /// store(memrefcast) -> store
2804 return foldMemRefCast(*this, getValueToStore());
2805}
2806
2807FailureOr<std::optional<SmallVector<Value>>>
2808StoreOp::bubbleDownCasts(OpBuilder &builder) {
2810 ValueRange());
2811}
2812
2813//===----------------------------------------------------------------------===//
2814// SubViewOp
2815//===----------------------------------------------------------------------===//
2816
2817void SubViewOp::getAsmResultNames(
2818 function_ref<void(Value, StringRef)> setNameFn) {
2819 setNameFn(getResult(), "subview");
2820}
2821
2822/// A subview result type can be fully inferred from the source type and the
2823/// static representation of offsets, sizes and strides. Special sentinels
2824/// encode the dynamic case.
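///
/// For instance (illustrative): with a source memref<8x16xf32> (identity
/// layout, i.e. strides [16, 1] and offset 0), offsets [2, 4], sizes [4, 4]
/// and strides [1, 2], the inferred type is
/// memref<4x4xf32, strided<[16, 2], offset: 36>>, because
/// offset = 0 + 2 * 16 + 4 * 1 = 36 and strides = [16 * 1, 1 * 2].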
2825MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2826 ArrayRef<int64_t> staticOffsets,
2827 ArrayRef<int64_t> staticSizes,
2828 ArrayRef<int64_t> staticStrides) {
2829 unsigned rank = sourceMemRefType.getRank();
2830 (void)rank;
2831 assert(staticOffsets.size() == rank && "staticOffsets length mismatch");
2832 assert(staticSizes.size() == rank && "staticSizes length mismatch");
2833 assert(staticStrides.size() == rank && "staticStrides length mismatch");
2834
2835 // Extract source offset and strides.
2836 auto [sourceStrides, sourceOffset] = sourceMemRefType.getStridesAndOffset();
2837
2838 // Compute target offset whose value is:
2839 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
2840 int64_t targetOffset = sourceOffset;
2841 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
2842 auto staticOffset = std::get<0>(it), sourceStride = std::get<1>(it);
2843 targetOffset = (SaturatedInteger::wrap(targetOffset) +
2844 SaturatedInteger::wrap(staticOffset) *
2845 SaturatedInteger::wrap(sourceStride))
2846 .asInteger();
2847 }
2848
2849 // Compute target stride whose value is:
2850 // `sourceStrides_i * staticStrides_i`.
2851 SmallVector<int64_t, 4> targetStrides;
2852 targetStrides.reserve(staticOffsets.size());
2853 for (auto it : llvm::zip(sourceStrides, staticStrides)) {
2854 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
2855 targetStrides.push_back((SaturatedInteger::wrap(sourceStride) *
2856 SaturatedInteger::wrap(staticStride))
2857 .asInteger());
2858 }
2859
2860 // The type is now known.
2861 return MemRefType::get(staticSizes, sourceMemRefType.getElementType(),
2862 StridedLayoutAttr::get(sourceMemRefType.getContext(),
2863 targetOffset, targetStrides),
2864 sourceMemRefType.getMemorySpace());
2865}
2866
2867MemRefType SubViewOp::inferResultType(MemRefType sourceMemRefType,
2868 ArrayRef<OpFoldResult> offsets,
2869 ArrayRef<OpFoldResult> sizes,
2870 ArrayRef<OpFoldResult> strides) {
2871 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2872 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2873 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2874 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2875 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2876 if (!hasValidSizesOffsets(staticOffsets))
2877 return {};
2878 if (!hasValidSizesOffsets(staticSizes))
2879 return {};
2880 if (!hasValidStrides(staticStrides))
2881 return {};
2882 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2883 staticSizes, staticStrides);
2884}
2885
2886MemRefType SubViewOp::inferRankReducedResultType(
2887 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2888 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2889 ArrayRef<int64_t> strides) {
2890 MemRefType inferredType =
2891 inferResultType(sourceRankedTensorType, offsets, sizes, strides);
2892 assert(inferredType.getRank() >= static_cast<int64_t>(resultShape.size()) &&
2893 "expected ");
2894 if (inferredType.getRank() == static_cast<int64_t>(resultShape.size()))
2895 return inferredType;
2896
2897 // Compute which dimensions are dropped.
2898 std::optional<llvm::SmallDenseSet<unsigned>> dimsToProject =
2899 computeRankReductionMask(inferredType.getShape(), resultShape);
2900 assert(dimsToProject.has_value() && "invalid rank reduction");
2901
2902 // Compute the layout and result type.
2903 auto inferredLayout = llvm::cast<StridedLayoutAttr>(inferredType.getLayout());
2904 SmallVector<int64_t> rankReducedStrides;
2905 rankReducedStrides.reserve(resultShape.size());
2906 for (auto [idx, value] : llvm::enumerate(inferredLayout.getStrides())) {
2907 if (!dimsToProject->contains(idx))
2908 rankReducedStrides.push_back(value);
2909 }
2910 return MemRefType::get(resultShape, inferredType.getElementType(),
2911 StridedLayoutAttr::get(inferredLayout.getContext(),
2912 inferredLayout.getOffset(),
2913 rankReducedStrides),
2914 inferredType.getMemorySpace());
2915}
2916
2917MemRefType SubViewOp::inferRankReducedResultType(
2918 ArrayRef<int64_t> resultShape, MemRefType sourceRankedTensorType,
2919 ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
2920 ArrayRef<OpFoldResult> strides) {
2921 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2922 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2923 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2924 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2925 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2926 return SubViewOp::inferRankReducedResultType(
2927 resultShape, sourceRankedTensorType, staticOffsets, staticSizes,
2928 staticStrides);
2929}
2930
2931// Build a SubViewOp with mixed static and dynamic entries and custom result
2932// type. If the type passed is nullptr, it is inferred.
2933void SubViewOp::build(OpBuilder &b, OperationState &result,
2934 MemRefType resultType, Value source,
2935 ArrayRef<OpFoldResult> offsets,
2936 ArrayRef<OpFoldResult> sizes,
2937 ArrayRef<OpFoldResult> strides,
2938 ArrayRef<NamedAttribute> attrs) {
2939 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
2940 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
2941 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
2942 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
2943 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
2944 auto sourceMemRefType = llvm::cast<MemRefType>(source.getType());
2945 // Structuring implementation this way avoids duplication between builders.
2946 if (!resultType) {
2947 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
2948 staticSizes, staticStrides);
2949 }
2950 result.addAttributes(attrs);
2951 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
2952 dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
2953 b.getDenseI64ArrayAttr(staticSizes),
2954 b.getDenseI64ArrayAttr(staticStrides));
2955}
2956
2957// Build a SubViewOp with mixed static and dynamic entries and inferred result
2958// type.
2959void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2960 ArrayRef<OpFoldResult> offsets,
2961 ArrayRef<OpFoldResult> sizes,
2962 ArrayRef<OpFoldResult> strides,
2963 ArrayRef<NamedAttribute> attrs) {
2964 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
2965}
2966
2967// Build a SubViewOp with static entries and inferred result type.
2968void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
2969 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2970 ArrayRef<int64_t> strides,
2971 ArrayRef<NamedAttribute> attrs) {
2972 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2973 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2974 return b.getI64IntegerAttr(v);
2975 }));
2976 SmallVector<OpFoldResult> sizeValues =
2977 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
2978 return b.getI64IntegerAttr(v);
2979 }));
2980 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
2981 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
2982 return b.getI64IntegerAttr(v);
2983 }));
2984 build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
2985}
2986
2987// Build a SubViewOp with static entries and custom result type. If the
2988// type passed is nullptr, it is inferred.
2989void SubViewOp::build(OpBuilder &b, OperationState &result,
2990 MemRefType resultType, Value source,
2991 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
2992 ArrayRef<int64_t> strides,
2993 ArrayRef<NamedAttribute> attrs) {
2994 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
2995 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
2996 return b.getI64IntegerAttr(v);
2997 }));
2998 SmallVector<OpFoldResult> sizeValues =
2999 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
3000 return b.getI64IntegerAttr(v);
3001 }));
3002 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3003 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
3004 return b.getI64IntegerAttr(v);
3005 }));
3006 build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
3007 attrs);
3008}
3009
3010// Build a SubViewOp with dynamic entries and custom result type. If the type
3011// passed is nullptr, it is inferred.
3012void SubViewOp::build(OpBuilder &b, OperationState &result,
3013 MemRefType resultType, Value source, ValueRange offsets,
3014 ValueRange sizes, ValueRange strides,
3015 ArrayRef<NamedAttribute> attrs) {
3016 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
3017 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
3018 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
3019 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
3020 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
3021 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
3022 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
3023}
3024
3025// Build a SubViewOp with dynamic entries and inferred result type.
3026void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
3027 ValueRange offsets, ValueRange sizes, ValueRange strides,
3028 ArrayRef<NamedAttribute> attrs) {
3029 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
3030}
3031
3032/// For ViewLikeOpInterface.
3033Value SubViewOp::getViewSource() { return getSource(); }
3034
3035/// Return true if `t1` and `t2` have equal offsets (both dynamic or of same
3036/// static value).
3037static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2) {
3038 int64_t t1Offset, t2Offset;
3039 SmallVector<int64_t> t1Strides, t2Strides;
3040 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3041 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3042 return succeeded(res1) && succeeded(res2) && t1Offset == t2Offset;
3043}
3044
3045/// Return true if `t1` and `t2` have equal strides (both dynamic or of same
3046/// static value). Dimensions of `t1` may be dropped in `t2`; these must be
3047/// marked as dropped in `droppedDims`.
3048static bool haveCompatibleStrides(MemRefType t1, MemRefType t2,
3049 const llvm::SmallBitVector &droppedDims) {
3050 assert(size_t(t1.getRank()) == droppedDims.size() &&
3051 "incorrect number of bits");
3052 assert(size_t(t1.getRank() - t2.getRank()) == droppedDims.count() &&
3053 "incorrect number of dropped dims");
3054 int64_t t1Offset, t2Offset;
3055 SmallVector<int64_t> t1Strides, t2Strides;
3056 auto res1 = t1.getStridesAndOffset(t1Strides, t1Offset);
3057 auto res2 = t2.getStridesAndOffset(t2Strides, t2Offset);
3058 if (failed(res1) || failed(res2))
3059 return false;
3060 for (int64_t i = 0, j = 0, e = t1.getRank(); i < e; ++i) {
3061 if (droppedDims[i])
3062 continue;
3063 if (t1Strides[i] != t2Strides[j])
3064 return false;
3065 ++j;
3066 }
3067 return true;
3068}
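// Illustrative sketch (not from the original file): for
//   t1 = memref<4x1x8xf32, strided<[8, 8, 1]>>
//   t2 = memref<4x8xf32, strided<[8, 1]>>
// with droppedDims = {1}, the remaining strides of t1 ([8, 1]) match those of
// t2 pairwise, so the two types have compatible strides.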
3069
3070static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
3071 SubViewOp op, Type expectedType) {
3072 auto memrefType = llvm::cast<ShapedType>(expectedType);
3073 switch (result) {
3074 case SliceVerificationResult::Success:
3075 return success();
3076 case SliceVerificationResult::RankTooLarge:
3077 return op->emitError("expected result rank to be smaller or equal to ")
3078 << "the source rank, but got " << op.getType();
3079 case SliceVerificationResult::SizeMismatch:
3080 return op->emitError("expected result type to be ")
3081 << expectedType
3082 << " or a rank-reduced version. (mismatch of result sizes), but got "
3083 << op.getType();
3084 case SliceVerificationResult::ElemTypeMismatch:
3085 return op->emitError("expected result element type to be ")
3086 << memrefType.getElementType() << ", but got " << op.getType();
3087 case SliceVerificationResult::MemSpaceMismatch:
3088 return op->emitError(
3089 "expected result and source memory spaces to match, but got ")
3090 << op.getType();
3091 case SliceVerificationResult::LayoutMismatch:
3092 return op->emitError("expected result type to be ")
3093 << expectedType
3094 << " or a rank-reduced version. (mismatch of result layout), but "
3095 "got "
3096 << op.getType();
3097 }
3098 llvm_unreachable("unexpected subview verification result");
3099}
3100
3101/// Verifier for SubViewOp.
3102LogicalResult SubViewOp::verify() {
3103 MemRefType baseType = getSourceType();
3104 MemRefType subViewType = getType();
3105 ArrayRef<int64_t> staticOffsets = getStaticOffsets();
3106 ArrayRef<int64_t> staticSizes = getStaticSizes();
3107 ArrayRef<int64_t> staticStrides = getStaticStrides();
3108
3109 // The base memref and the view memref should be in the same memory space.
3110 if (baseType.getMemorySpace() != subViewType.getMemorySpace())
3111 return emitError("different memory spaces specified for base memref "
3112 "type ")
3113 << baseType << " and subview memref type " << subViewType;
3114
3115 // Verify that the base memref type has a strided layout map.
3116 if (!baseType.isStrided())
3117 return emitError("base type ") << baseType << " is not strided";
3118
3119 // Compute the expected result type, assuming that there are no rank
3120 // reductions.
3121 MemRefType expectedType = SubViewOp::inferResultType(
3122 baseType, staticOffsets, staticSizes, staticStrides);
3123
3124 // Verify all properties of a shaped type: rank, element type and dimension
3125 // sizes. This takes into account potential rank reductions.
3126 auto shapedTypeVerification = isRankReducedType(
3127 /*originalType=*/expectedType, /*candidateReducedType=*/subViewType);
3128 if (shapedTypeVerification != SliceVerificationResult::Success)
3129 return produceSubViewErrorMsg(shapedTypeVerification, *this, expectedType);
3130
3131 // Make sure that the memory space did not change.
3132 if (expectedType.getMemorySpace() != subViewType.getMemorySpace())
3133 return produceSubViewErrorMsg(SliceVerificationResult::MemSpaceMismatch,
3134 *this, expectedType);
3135
3136 // Verify the offset of the layout map.
3137 if (!haveCompatibleOffsets(expectedType, subViewType))
3138 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3139 *this, expectedType);
3140
3141 // The only thing that's left to verify now are the strides. First, compute
3142 // the unused dimensions due to rank reductions. We have to look at sizes and
3143 // strides to decide which dimensions were dropped. This function also
3144 // partially verifies strides in case of rank reductions.
3145 auto unusedDims = computeMemRefRankReductionMask(expectedType, subViewType,
3146 getMixedSizes());
3147 if (failed(unusedDims))
3148 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3149 *this, expectedType);
3150
3151 // Strides must match.
3152 if (!haveCompatibleStrides(expectedType, subViewType, *unusedDims))
3153 return produceSubViewErrorMsg(SliceVerificationResult::LayoutMismatch,
3154 *this, expectedType);
3155
3156 // Verify that offsets, sizes, strides do not run out-of-bounds with respect
3157 // to the base memref.
3158 SliceBoundsVerificationResult boundsResult =
3159 verifyInBoundsSlice(baseType.getShape(), staticOffsets, staticSizes,
3160 staticStrides, /*generateErrorMessage=*/true);
3161 if (!boundsResult.isValid)
3162 return getOperation()->emitError(boundsResult.errorMessage);
3163
3164 return success();
3165}
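// Illustrative sketch (not from the original file): the following
// rank-reducing subview passes verification; the dropped unit dimension is
// deduced from the sizes, and the remaining offset/strides match:
//
//   %1 = memref.subview %0[0, 2, 0] [4, 1, 8] [1, 1, 1]
//       : memref<4x4x8xf32> to memref<4x8xf32, strided<[32, 1], offset: 16>>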
3166
3167raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
3168 return os << "range " << range.offset << ":" << range.size << ":"
3169 << range.stride;
3170}
3171
3172/// Return the list of Range (i.e. offset, size, stride). Each Range
3173/// entry contains either the dynamic value or a ConstantIndexOp constructed
3174/// with `b` at location `loc`.
3175SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
3176 OpBuilder &b, Location loc) {
3177 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
3178 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
3179 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
3180 SmallVector<Range, 8> res;
3181 unsigned rank = ranks[0];
3182 res.reserve(rank);
3183 for (unsigned idx = 0; idx < rank; ++idx) {
3184 Value offset =
3185 op.isDynamicOffset(idx)
3186 ? op.getDynamicOffset(idx)
3187 : arith::ConstantIndexOp::create(b, loc, op.getStaticOffset(idx));
3188 Value size =
3189 op.isDynamicSize(idx)
3190 ? op.getDynamicSize(idx)
3191 : arith::ConstantIndexOp::create(b, loc, op.getStaticSize(idx));
3192 Value stride =
3193 op.isDynamicStride(idx)
3194 ? op.getDynamicStride(idx)
3195 : arith::ConstantIndexOp::create(b, loc, op.getStaticStride(idx));
3196 res.emplace_back(Range{offset, size, stride});
3197 }
3198 return res;
3199}
3200
3201/// Compute the canonical result type of a SubViewOp. Call `inferResultType`
3202/// to deduce the result type for the given `sourceType`. Additionally, reduce
3203/// the rank of the inferred result type if `currentResultType` is lower rank
3204/// than `currentSourceType`. Use this signature if `sourceType` is updated
3205/// together with the result type. In this case, it is important to compute
3206/// the dropped dimensions using `currentSourceType` whose strides align with
3207/// `currentResultType`.
3208static MemRefType getCanonicalSubViewResultType(
3209 MemRefType currentResultType, MemRefType currentSourceType,
3210 MemRefType sourceType, ArrayRef<OpFoldResult> mixedOffsets,
3211 ArrayRef<OpFoldResult> mixedSizes, ArrayRef<OpFoldResult> mixedStrides) {
3212 MemRefType nonRankReducedType = SubViewOp::inferResultType(
3213 sourceType, mixedOffsets, mixedSizes, mixedStrides);
3214 FailureOr<llvm::SmallBitVector> unusedDims = computeMemRefRankReductionMask(
3215 currentSourceType, currentResultType, mixedSizes);
3216 if (failed(unusedDims))
3217 return nullptr;
3218
3219 auto layout = llvm::cast<StridedLayoutAttr>(nonRankReducedType.getLayout());
3220 SmallVector<int64_t> shape, strides;
3221 unsigned numDimsAfterReduction =
3222 nonRankReducedType.getRank() - unusedDims->count();
3223 shape.reserve(numDimsAfterReduction);
3224 strides.reserve(numDimsAfterReduction);
3225 for (const auto &[idx, size, stride] :
3226 llvm::zip(llvm::seq<unsigned>(0, nonRankReducedType.getRank()),
3227 nonRankReducedType.getShape(), layout.getStrides())) {
3228 if (unusedDims->test(idx))
3229 continue;
3230 shape.push_back(size);
3231 strides.push_back(stride);
3232 }
3233
3234 return MemRefType::get(shape, nonRankReducedType.getElementType(),
3235 StridedLayoutAttr::get(sourceType.getContext(),
3236 layout.getOffset(), strides),
3237 nonRankReducedType.getMemorySpace());
3238}
3239
3240Value mlir::memref::createCanonicalRankReducingSubViewOp(
3241 OpBuilder &b, Location loc, Value memref, ArrayRef<int64_t> targetShape) {
3242 auto memrefType = llvm::cast<MemRefType>(memref.getType());
3243 unsigned rank = memrefType.getRank();
3244 SmallVector<OpFoldResult> offsets(rank, b.getIndexAttr(0));
3245 SmallVector<OpFoldResult> sizes = getMixedSizes(b, loc, memref);
3246 SmallVector<OpFoldResult> strides(rank, b.getIndexAttr(1));
3247 MemRefType targetType = SubViewOp::inferRankReducedResultType(
3248 targetShape, memrefType, offsets, sizes, strides);
3249 return b.createOrFold<memref::SubViewOp>(loc, targetType, memref, offsets,
3250 sizes, strides);
3251}
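// Illustrative sketch (not from the original file): for a value %src of type
// memref<4x1x8xf32> and targetShape [4, 8], the helper builds the canonical
// rank-reducing subview:
//
//   %1 = memref.subview %src[0, 0, 0] [4, 1, 8] [1, 1, 1]
//       : memref<4x1x8xf32> to memref<4x8xf32, strided<[8, 1]>>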
3252
3253FailureOr<Value> SubViewOp::rankReduceIfNeeded(OpBuilder &b, Location loc,
3254 Value value,
3255 ArrayRef<int64_t> desiredShape) {
3256 auto sourceMemrefType = llvm::dyn_cast<MemRefType>(value.getType());
3257 assert(sourceMemrefType && "not a ranked memref type");
3258 auto sourceShape = sourceMemrefType.getShape();
3259 if (sourceShape.equals(desiredShape))
3260 return value;
3261 auto maybeRankReductionMask =
3262 mlir::computeRankReductionMask(sourceShape, desiredShape);
3263 if (!maybeRankReductionMask)
3264 return failure();
3265 return createCanonicalRankReducingSubViewOp(b, loc, value, desiredShape);
3266}
3267
3268/// Helper method to check if a `subview` operation is trivially a no-op. This
3269/// is the case if all offsets are zero, all strides are one, and the source
3270/// shape matches the sizes of the subview. In such cases, the subview can
3271/// be folded into its source.
3272static bool isTrivialSubViewOp(SubViewOp subViewOp) {
3273 if (subViewOp.getSourceType().getRank() != subViewOp.getType().getRank())
3274 return false;
3275
3276 auto mixedOffsets = subViewOp.getMixedOffsets();
3277 auto mixedSizes = subViewOp.getMixedSizes();
3278 auto mixedStrides = subViewOp.getMixedStrides();
3279
3280 // Check offsets are zero.
3281 if (llvm::any_of(mixedOffsets, [](OpFoldResult ofr) {
3282 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3283 return !intValue || intValue.value() != 0;
3284 }))
3285 return false;
3286
3287 // Check strides are one.
3288 if (llvm::any_of(mixedStrides, [](OpFoldResult ofr) {
3289 std::optional<int64_t> intValue = getConstantIntValue(ofr);
3290 return !intValue || intValue.value() != 1;
3291 }))
3292 return false;
3293
3294 // Check that all size values are static and match the (static) source shape.
3295 ArrayRef<int64_t> sourceShape = subViewOp.getSourceType().getShape();
3296 for (const auto &size : llvm::enumerate(mixedSizes)) {
3297 std::optional<int64_t> intValue = getConstantIntValue(size.value());
3298 if (!intValue || *intValue != sourceShape[size.index()])
3299 return false;
3300 }
3301 // All conditions met. The `SubViewOp` is foldable as a no-op.
3302 return true;
3303}
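// Illustrative sketch (not from the original file): the subview below is
// trivial (zero offsets, unit strides, sizes equal to the static source
// shape) and can be folded away by the pattern further down:
//
//   %1 = memref.subview %0[0, 0] [4, 8] [1, 1]
//       : memref<4x8xf32> to memref<4x8xf32>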
3304
3305namespace {
3306/// Pattern to rewrite a subview op with MemRefCast arguments.
3307/// This essentially pushes memref.cast past its consuming subview when
3308/// `canFoldIntoConsumerOp` is true.
3309///
3310/// Example:
3311/// ```
3312/// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
3313/// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
3314/// memref<?x?xf32> to memref<3x4xf32, strided<[?, 1], offset: ?>>
3315/// ```
3316/// is rewritten into:
3317/// ```
3318/// %0 = memref.subview %V[0, 0][3, 4][1, 1] : memref<16x16xf32> to memref<3x4xf32, strided<[16, 1], offset: 0>>
3319/// %1 = memref.cast %0: memref<3x4xf32, strided<[16, 1], offset: 0>> to
3320/// memref<3x4xf32, strided<[?, 1], offset: ?>>
3321/// ```
3322class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
3323public:
3324 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3325
3326 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3327 PatternRewriter &rewriter) const override {
3328 // If any operand is a constant, bail out and let the constant-argument
3329 // folder pattern kick in.
3330 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
3331 return matchPattern(operand, matchConstantIndex());
3332 }))
3333 return failure();
3334
3335 auto castOp = subViewOp.getSource().getDefiningOp<CastOp>();
3336 if (!castOp)
3337 return failure();
3338
3339 if (!CastOp::canFoldIntoConsumerOp(castOp))
3340 return failure();
3341
3342 // Compute the SubViewOp result type after folding the MemRefCastOp. Use
3343 // the MemRefCastOp source operand type to infer the result type and the
3344 // current SubViewOp source operand type to compute the dropped dimensions
3345 // if the operation is rank-reducing.
3346 auto resultType = getCanonicalSubViewResultType(
3347 subViewOp.getType(), subViewOp.getSourceType(),
3348 llvm::cast<MemRefType>(castOp.getSource().getType()),
3349 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
3350 subViewOp.getMixedStrides());
3351 if (!resultType)
3352 return failure();
3353
3354 Value newSubView = SubViewOp::create(
3355 rewriter, subViewOp.getLoc(), resultType, castOp.getSource(),
3356 subViewOp.getOffsets(), subViewOp.getSizes(), subViewOp.getStrides(),
3357 subViewOp.getStaticOffsets(), subViewOp.getStaticSizes(),
3358 subViewOp.getStaticStrides());
3359 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3360 newSubView);
3361 return success();
3362 }
3363};
3364
3365/// Canonicalize subview ops that are no-ops; when the source type differs
3366/// from the result type (e.g. due to the use of `affine_map`), insert a cast.
3367class TrivialSubViewOpFolder final : public OpRewritePattern<SubViewOp> {
3368public:
3369 using OpRewritePattern<SubViewOp>::OpRewritePattern;
3370
3371 LogicalResult matchAndRewrite(SubViewOp subViewOp,
3372 PatternRewriter &rewriter) const override {
3373 if (!isTrivialSubViewOp(subViewOp))
3374 return failure();
3375 if (subViewOp.getSourceType() == subViewOp.getType()) {
3376 rewriter.replaceOp(subViewOp, subViewOp.getSource());
3377 return success();
3378 }
3379 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
3380 subViewOp.getSource());
3381 return success();
3382 }
3383};
3384} // namespace
3385
3386/// Return the canonical type of the result of a subview.
3387struct SubViewReturnTypeCanonicalizer {
3388 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
3389 ArrayRef<OpFoldResult> mixedSizes,
3390 ArrayRef<OpFoldResult> mixedStrides) {
3391 // Infer a memref type without taking into account any rank reductions.
3392 MemRefType resTy = SubViewOp::inferResultType(
3393 op.getSourceType(), mixedOffsets, mixedSizes, mixedStrides);
3394 if (!resTy)
3395 return {};
3396 MemRefType nonReducedType = resTy;
3397
3398 // Directly return the non-rank reduced type if there are no dropped dims.
3399 llvm::SmallBitVector droppedDims = op.getDroppedDims();
3400 if (droppedDims.none())
3401 return nonReducedType;
3402
3403 // Take the strides and offset from the non-rank reduced type.
3404 auto [nonReducedStrides, offset] = nonReducedType.getStridesAndOffset();
3405
3406 // Drop dims from shape and strides.
3407 SmallVector<int64_t> targetShape;
3408 SmallVector<int64_t> targetStrides;
3409 for (int64_t i = 0; i < static_cast<int64_t>(mixedSizes.size()); ++i) {
3410 if (droppedDims.test(i))
3411 continue;
3412 targetStrides.push_back(nonReducedStrides[i]);
3413 targetShape.push_back(nonReducedType.getDimSize(i));
3414 }
3415
3416 return MemRefType::get(targetShape, nonReducedType.getElementType(),
3417 StridedLayoutAttr::get(nonReducedType.getContext(),
3418 offset, targetStrides),
3419 nonReducedType.getMemorySpace());
3420 }
3421};
3422
3423/// A canonicalizer wrapper to replace SubViewOps.
3424struct SubViewCanonicalizer {
3425 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
3426 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
3427 }
3428};
3429
3430void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3431 MLIRContext *context) {
3432 results
3433 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
3434 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
3435 SubViewOpMemRefCastFolder, TrivialSubViewOpFolder>(context);
3436}
3437
3438OpFoldResult SubViewOp::fold(FoldAdaptor adaptor) {
3439 MemRefType sourceMemrefType = getSource().getType();
3440 MemRefType resultMemrefType = getResult().getType();
3441 auto resultLayout =
3442 dyn_cast_if_present<StridedLayoutAttr>(resultMemrefType.getLayout());
3443
3444 if (resultMemrefType == sourceMemrefType &&
3445 resultMemrefType.hasStaticShape() &&
3446 (!resultLayout || resultLayout.hasStaticLayout())) {
3447 return getViewSource();
3448 }
3449
3450 // Fold subview(subview(x)), where both subviews have the same size and the
3451 // second subview's offsets are all zero. (I.e., the second subview is a
3452 // no-op.)
3453 if (auto srcSubview = getViewSource().getDefiningOp<SubViewOp>()) {
3454 auto srcSizes = srcSubview.getMixedSizes();
3455 auto sizes = getMixedSizes();
3456 auto offsets = getMixedOffsets();
3457 bool allOffsetsZero = llvm::all_of(offsets, isZeroInteger);
3458 auto strides = getMixedStrides();
3459 bool allStridesOne = llvm::all_of(strides, isOneInteger);
3460 bool allSizesSame = llvm::equal(sizes, srcSizes);
3461 if (allOffsetsZero && allStridesOne && allSizesSame &&
3462 resultMemrefType == sourceMemrefType)
3463 return getViewSource();
3464 }
3465
3466 return {};
3467}
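// Illustrative sketch (not from the original file): the second subview below
// has zero offsets, unit strides, the same sizes as the first, and the same
// result type as its source, so fold() replaces it with %1:
//
//   %1 = memref.subview %0[2, 2] [4, 4] [1, 1]
//       : memref<8x8xf32> to memref<4x4xf32, strided<[8, 1], offset: 18>>
//   %2 = memref.subview %1[0, 0] [4, 4] [1, 1]
//       : memref<4x4xf32, strided<[8, 1], offset: 18>>
//      to memref<4x4xf32, strided<[8, 1], offset: 18>>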
3468
3469FailureOr<std::optional<SmallVector<Value>>>
3470SubViewOp::bubbleDownCasts(OpBuilder &builder) {
3471 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3472}
3473
3474void SubViewOp::inferStridedMetadataRanges(
3475 ArrayRef<StridedMetadataRange> ranges, GetIntRangeFn getIntRange,
3476 SetStridedMetadataRangeFn setMetadata, int32_t indexBitwidth) {
3477 auto isUninitialized =
3478 +[](IntegerValueRange range) { return range.isUninitialized(); };
3479
3480 // Bail early if any of the operands metadata is not ready:
3481 SmallVector<IntegerValueRange> offsetOperands =
3482 getIntValueRanges(getMixedOffsets(), getIntRange, indexBitwidth);
3483 if (llvm::any_of(offsetOperands, isUninitialized))
3484 return;
3485
3486 SmallVector<IntegerValueRange> sizeOperands =
3487 getIntValueRanges(getMixedSizes(), getIntRange, indexBitwidth);
3488 if (llvm::any_of(sizeOperands, isUninitialized))
3489 return;
3490
3491 SmallVector<IntegerValueRange> stridesOperands =
3492 getIntValueRanges(getMixedStrides(), getIntRange, indexBitwidth);
3493 if (llvm::any_of(stridesOperands, isUninitialized))
3494 return;
3495
3496 StridedMetadataRange sourceRange =
3497 ranges[getSourceMutable().getOperandNumber()];
3498 if (sourceRange.isUninitialized())
3499 return;
3500
3501 ArrayRef<ConstantIntRanges> srcStrides = sourceRange.getStrides();
3502
3503 // Get the dropped dims.
3504 llvm::SmallBitVector droppedDims = getDroppedDims();
3505
3506 // Compute the new offset, strides and sizes.
3507 ConstantIntRanges offset = sourceRange.getOffsets()[0];
3508 SmallVector<ConstantIntRanges> strides, sizes;
3509
3510 for (size_t i = 0, e = droppedDims.size(); i < e; ++i) {
3511 bool dropped = droppedDims.test(i);
3512 // Compute the new offset.
3513 ConstantIntRanges off =
3514 intrange::inferMul({offsetOperands[i].getValue(), srcStrides[i]});
3515 offset = intrange::inferAdd({offset, off});
3516
3517 // Skip dropped dimensions.
3518 if (dropped)
3519 continue;
3520 // Multiply the strides.
3521 strides.push_back(
3522 intrange::inferMul({stridesOperands[i].getValue(), srcStrides[i]}));
3523 // Get the sizes.
3524 sizes.push_back(sizeOperands[i].getValue());
3525 }
3526
3527 setMetadata(getResult(),
3528 StridedMetadataRange::getRanked(
3529 SmallVector<ConstantIntRanges>({std::move(offset)}),
3530 std::move(sizes), std::move(strides)));
3531}
3532
3533//===----------------------------------------------------------------------===//
3534// TransposeOp
3535//===----------------------------------------------------------------------===//
3536
3537void TransposeOp::getAsmResultNames(
3538 function_ref<void(Value, StringRef)> setNameFn) {
3539 setNameFn(getResult(), "transpose");
3540}
3541
3542/// Build a strided memref type by applying `permutationMap` to `memRefType`.
3543static MemRefType inferTransposeResultType(MemRefType memRefType,
3544 AffineMap permutationMap) {
3545 auto originalSizes = memRefType.getShape();
3546 auto [originalStrides, offset] = memRefType.getStridesAndOffset();
3547 assert(originalStrides.size() == static_cast<unsigned>(memRefType.getRank()));
3548
3549 // Compute permuted sizes and strides.
3550 auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3551 auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
3552
3553 return MemRefType::Builder(memRefType)
3554 .setShape(sizes)
3555 .setLayout(
3556 StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
3557}
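// Illustrative sketch (not from the original file): transposing a contiguous
// memref<4x8xf32> (strides [8, 1]) with the permutation (d0, d1) -> (d1, d0)
// permutes both sizes and strides and keeps the offset, giving
// memref<8x4xf32, strided<[1, 8]>>.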
3558
3559void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
3560 AffineMapAttr permutation,
3561 ArrayRef<NamedAttribute> attrs) {
3562 auto permutationMap = permutation.getValue();
3563 assert(permutationMap);
3564
3565 auto memRefType = llvm::cast<MemRefType>(in.getType());
3566 // Compute result type.
3567 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
3568
3569 result.addAttribute(TransposeOp::getPermutationAttrStrName(), permutation);
3570 build(b, result, resultType, in, attrs);
3571}
3572
3573// transpose $in $permutation attr-dict : type($in) `to` type(results)
3574void TransposeOp::print(OpAsmPrinter &p) {
3575 p << " " << getIn() << " " << getPermutation();
3576 p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrStrName()});
3577 p << " : " << getIn().getType() << " to " << getType();
3578}
3579
3580ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
3581 OpAsmParser::UnresolvedOperand in;
3582 AffineMap permutation;
3583 MemRefType srcType, dstType;
3584 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
3585 parser.parseOptionalAttrDict(result.attributes) ||
3586 parser.parseColonType(srcType) ||
3587 parser.resolveOperand(in, srcType, result.operands) ||
3588 parser.parseKeywordType("to", dstType) ||
3589 parser.addTypeToList(dstType, result.types))
3590 return failure();
3591
3592 result.addAttribute(TransposeOp::getPermutationAttrStrName(),
3593 AffineMapAttr::get(permutation));
3594 return success();
3595}
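// Illustrative sketch of the custom assembly format (not from the original
// file):
//
//   %1 = memref.transpose %0 (d0, d1) -> (d1, d0)
//       : memref<?x?xf32, strided<[?, ?], offset: ?>>
//      to memref<?x?xf32, strided<[?, ?], offset: ?>>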
3596
3597LogicalResult TransposeOp::verify() {
3598 if (!getPermutation().isPermutation())
3599 return emitOpError("expected a permutation map");
3600 if (getPermutation().getNumDims() != getIn().getType().getRank())
3601 return emitOpError("expected a permutation map of same rank as the input");
3602
3603 auto srcType = llvm::cast<MemRefType>(getIn().getType());
3604 auto resultType = llvm::cast<MemRefType>(getType());
3605 auto canonicalResultType = inferTransposeResultType(srcType, getPermutation())
3606 .canonicalizeStridedLayout();
3607
3608 if (resultType.canonicalizeStridedLayout() != canonicalResultType)
3609 return emitOpError("result type ")
3610 << resultType
3611 << " is not equivalent to the canonical transposed input type "
3612 << canonicalResultType;
3613 return success();
3614}
3615
3616OpFoldResult TransposeOp::fold(FoldAdaptor) {
3617 // First check for an identity permutation; it can be folded away if the
3618 // input and result types are already identical.
3619 if (getPermutation().isIdentity() && getType() == getIn().getType())
3620 return getIn();
3621 // Fold two consecutive memref.transpose Ops into one by composing their
3622 // permutation maps.
3623 if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3624 AffineMap composedPermutation =
3625 getPermutation().compose(otherTransposeOp.getPermutation());
3626 getInMutable().assign(otherTransposeOp.getIn());
3627 setPermutation(composedPermutation);
3628 return getResult();
3629 }
3630 return {};
3631}
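// Illustrative sketch (not from the original file): two consecutive
// transposes with the swap permutation (d0, d1) -> (d1, d0) compose to the
// identity map; fold() rewrites the second transpose to read the original
// input with the composed (identity) permutation, and the identity fold can
// then remove it entirely once the input and result types match.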
3632
3633FailureOr<std::optional<SmallVector<Value>>>
3634TransposeOp::bubbleDownCasts(OpBuilder &builder) {
3635 return bubbleDownCastsPassthroughOpImpl(*this, builder, getInMutable());
3636}
3637
3638//===----------------------------------------------------------------------===//
3639// ViewOp
3640//===----------------------------------------------------------------------===//
3641
3642void ViewOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
3643 setNameFn(getResult(), "view");
3644}
3645
3646LogicalResult ViewOp::verify() {
3647 auto baseType = llvm::cast<MemRefType>(getOperand(0).getType());
3648 auto viewType = getType();
3649
3650 // The base memref should have identity layout map (or none).
3651 if (!baseType.getLayout().isIdentity())
3652 return emitError("unsupported map for base memref type ") << baseType;
3653
3654 // The result memref should have identity layout map (or none).
3655 if (!viewType.getLayout().isIdentity())
3656 return emitError("unsupported map for result memref type ") << viewType;
3657
3658 // The base memref and the view memref should be in the same memory space.
3659 if (baseType.getMemorySpace() != viewType.getMemorySpace())
3660 return emitError("different memory spaces specified for base memref "
3661 "type ")
3662 << baseType << " and view memref type " << viewType;
3663
3664 // Verify that we have the correct number of sizes for the result type.
3665 unsigned numDynamicDims = viewType.getNumDynamicDims();
3666 if (getSizes().size() != numDynamicDims)
3667 return emitError("incorrect number of size operands for type ") << viewType;
3668
3669 return success();
3670}
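// Illustrative sketch (not from the original file): a view of a raw i8
// buffer; the result type has one dynamic dimension, so exactly one size
// operand is required:
//
//   %1 = memref.view %0[%byte_shift][%size]
//       : memref<2048xi8> to memref<?x4xf32>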
3671
3672Value ViewOp::getViewSource() { return getSource(); }
3673
3674OpFoldResult ViewOp::fold(FoldAdaptor adaptor) {
3675 MemRefType sourceMemrefType = getSource().getType();
3676 MemRefType resultMemrefType = getResult().getType();
3677
3678 if (resultMemrefType == sourceMemrefType && resultMemrefType.hasStaticShape())
3679 return getViewSource();
3680
3681 return {};
3682}
3683
3684namespace {
3685
3686struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
3687 using OpRewritePattern<ViewOp>::OpRewritePattern;
3688
3689 LogicalResult matchAndRewrite(ViewOp viewOp,
3690 PatternRewriter &rewriter) const override {
3691 // Return if none of the operands are constants.
3692 if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
3693 return matchPattern(operand, matchConstantIndex());
3694 }))
3695 return failure();
3696
3697 // Get result memref type.
3698 auto memrefType = viewOp.getType();
3699
3700 // Get offset from old memref view type 'memRefType'.
3701 int64_t oldOffset;
3702 SmallVector<int64_t, 4> oldStrides;
3703 if (failed(memrefType.getStridesAndOffset(oldStrides, oldOffset)))
3704 return failure();
3705 assert(oldOffset == 0 && "Expected 0 offset");
3706
3707 SmallVector<Value, 4> newOperands;
3708
3709 // Offset cannot be folded into result type.
3710
3711 // Fold any dynamic dim operands which are produced by a constant.
3712 SmallVector<int64_t, 4> newShapeConstants;
3713 newShapeConstants.reserve(memrefType.getRank());
3714
3715 unsigned dynamicDimPos = 0;
3716 unsigned rank = memrefType.getRank();
3717 for (unsigned dim = 0, e = rank; dim < e; ++dim) {
3718 int64_t dimSize = memrefType.getDimSize(dim);
3719 // If this is already a static dimension, keep it.
3720 if (ShapedType::isStatic(dimSize)) {
3721 newShapeConstants.push_back(dimSize);
3722 continue;
3723 }
3724 auto *defOp = viewOp.getSizes()[dynamicDimPos].getDefiningOp();
3725 if (auto constantIndexOp =
3726 dyn_cast_or_null<arith::ConstantIndexOp>(defOp)) {
3727 // Dynamic shape dimension will be folded.
3728 newShapeConstants.push_back(constantIndexOp.value());
3729 } else {
3730 // Dynamic shape dimension not folded; copy operand from old memref.
3731 newShapeConstants.push_back(dimSize);
3732 newOperands.push_back(viewOp.getSizes()[dynamicDimPos]);
3733 }
3734 dynamicDimPos++;
3735 }
3736
3737 // Create new memref type with constant folded dims.
3738 MemRefType newMemRefType =
3739 MemRefType::Builder(memrefType).setShape(newShapeConstants);
3740 // Nothing new, don't fold.
3741 if (newMemRefType == memrefType)
3742 return failure();
3743
3744 // Create new ViewOp.
3745 auto newViewOp = ViewOp::create(rewriter, viewOp.getLoc(), newMemRefType,
3746 viewOp.getOperand(0), viewOp.getByteShift(),
3747 newOperands);
3748 // Insert a cast so we have the same type as the old memref type.
3749 rewriter.replaceOpWithNewOp<CastOp>(viewOp, viewOp.getType(), newViewOp);
3750 return success();
3751 }
3752};
3753
3754struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
3755 using OpRewritePattern<ViewOp>::OpRewritePattern;
3756
3757 LogicalResult matchAndRewrite(ViewOp viewOp,
3758 PatternRewriter &rewriter) const override {
3759 Value memrefOperand = viewOp.getOperand(0);
3760 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
3761 if (!memrefCastOp)
3762 return failure();
3763 Value allocOperand = memrefCastOp.getOperand();
3764 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
3765 if (!allocOp)
3766 return failure();
3767 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
3768 viewOp.getByteShift(),
3769 viewOp.getSizes());
3770 return success();
3771 }
3772};
3773
3774} // namespace
3775
3776void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
3777 MLIRContext *context) {
3778 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
3779}
3780
3781FailureOr<std::optional<SmallVector<Value>>>
3782ViewOp::bubbleDownCasts(OpBuilder &builder) {
3783 return bubbleDownCastsPassthroughOpImpl(*this, builder, getSourceMutable());
3784}
3785
3786//===----------------------------------------------------------------------===//
3787// AtomicRMWOp
3788//===----------------------------------------------------------------------===//
3789
3790LogicalResult AtomicRMWOp::verify() {
3791 if (getMemRefType().getRank() != getNumOperands() - 2)
3792 return emitOpError(
3793 "expects the number of subscripts to be equal to memref rank");
3794 switch (getKind()) {
3795 case arith::AtomicRMWKind::addf:
3796 case arith::AtomicRMWKind::maximumf:
3797 case arith::AtomicRMWKind::minimumf:
3798 case arith::AtomicRMWKind::mulf:
3799 if (!llvm::isa<FloatType>(getValue().getType()))
3800 return emitOpError() << "with kind '"
3801 << arith::stringifyAtomicRMWKind(getKind())
3802 << "' expects a floating-point type";
3803 break;
3804 case arith::AtomicRMWKind::addi:
3805 case arith::AtomicRMWKind::maxs:
3806 case arith::AtomicRMWKind::maxu:
3807 case arith::AtomicRMWKind::mins:
3808 case arith::AtomicRMWKind::minu:
3809 case arith::AtomicRMWKind::muli:
3810 case arith::AtomicRMWKind::ori:
3811 case arith::AtomicRMWKind::xori:
3812 case arith::AtomicRMWKind::andi:
3813 if (!llvm::isa<IntegerType>(getValue().getType()))
3814 return emitOpError() << "with kind '"
3815 << arith::stringifyAtomicRMWKind(getKind())
3816 << "' expects an integer type";
3817 break;
3818 default:
3819 break;
3820 }
3821 return success();
3822}
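// Illustrative sketch (not from the original file, syntax per the memref
// dialect documentation): a floating-point kind requires a floating-point
// value, and the number of indices equals the memref rank:
//
//   %old = memref.atomic_rmw addf %value, %buf[%i]
//       : (f32, memref<10xf32>) -> f32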
3823
3824OpFoldResult AtomicRMWOp::fold(FoldAdaptor adaptor) {
3825 /// atomicrmw(memrefcast) -> atomicrmw
3826 if (succeeded(foldMemRefCast(*this, getValue())))
3827 return getResult();
3828 return OpFoldResult();
3829}
3830
3831FailureOr<std::optional<SmallVector<Value>>>
3832AtomicRMWOp::bubbleDownCasts(OpBuilder &builder) {
3833 return bubbleDownInPlaceMemorySpaceCastImpl(getMemrefMutable(),
3834 getResult());
3835}
3836
3837//===----------------------------------------------------------------------===//
3838// TableGen'd op method definitions
3839//===----------------------------------------------------------------------===//
3840
3841#define GET_OP_CLASSES
3842#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static bool hasSideEffects(Operation *op)
static bool isPermutation(const std::vector< PermutationTy > &permutation)
Definition IRAffine.cpp:67
static Type getElementType(Type type)
Determine the element type of type.
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
auto load
static void constifyIndexValues(SmallVectorImpl< OpFoldResult > &values, ArrayRef< int64_t > constValues)
Helper function that sets values[i] to constValues[i] if the latter is a static value,...
Definition MemRefOps.cpp:96
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op, TypeAttr type, Attribute initialValue)
static LogicalResult verifyCollapsedShape(Operation *op, ArrayRef< int64_t > collapsedShape, ArrayRef< int64_t > expandedShape, ArrayRef< ReassociationIndices > reassociation, bool allowMultipleDynamicDimsPerGroup)
Helper function for verifying the shape of ExpandShapeOp and ResultShapeOp result and operand.
static bool isOpItselfPotentialAutomaticAllocation(Operation *op)
Given an operation, return whether this op itself could allocate an AutomaticAllocationScopeResource.
static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap)
Build a strided memref type by applying permutationMap to memRefType.
static bool isGuaranteedAutomaticAllocation(Operation *op)
Given an operation, return whether this op is guaranteed to allocate an AutomaticAllocationScopeResou...
static FailureOr< StridedLayoutAttr > computeExpandedLayoutMap(MemRefType srcType, ArrayRef< int64_t > resultShape, ArrayRef< ReassociationIndices > reassociation)
Compute the layout map after expanding a given source MemRef type with the specified reassociation in...
static bool haveCompatibleOffsets(MemRefType t1, MemRefType t2)
Return true if t1 and t2 have equal offsets (both dynamic or of same static value).
static LogicalResult FoldCopyOfCast(CopyOp op)
If the source/target of a CopyOp is a CastOp that does not modify the shape and element type,...
static bool replaceConstantUsesOf(OpBuilder &rewriter, Location loc, Container values, ArrayRef< OpFoldResult > maybeConstants)
Helper function to perform the replacement of all constant uses of values by a materialized constant ...
static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, SubViewOp op, Type expectedType)
static MemRefType getCanonicalSubViewResultType(MemRefType currentResultType, MemRefType currentSourceType, MemRefType sourceType, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
Compute the canonical result type of a SubViewOp.
static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue)
static std::tuple< MemorySpaceCastOpInterface, PtrLikeTypeInterface, Type > getMemorySpaceCastInfo(BaseMemRefType resultTy, Value src)
Helper function to retrieve a lossless memory-space cast, and the corresponding new result memref typ...
static FailureOr< llvm::SmallBitVector > computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType, ArrayRef< OpFoldResult > sizes)
Given the originalType and a candidateReducedType whose shape is assumed to be a subset of originalTy...
static bool isTrivialSubViewOp(SubViewOp subViewOp)
Helper method to check if a subview operation is trivially a no-op.
static bool lastNonTerminatorInRegion(Operation *op)
Return whether this op is the last non terminating op in a region.
static std::map< int64_t, unsigned > getNumOccurences(ArrayRef< int64_t > vals)
Return a map with key being elements in vals and data being number of occurences of it.
static bool haveCompatibleStrides(MemRefType t1, MemRefType t2, const llvm::SmallBitVector &droppedDims)
Return true if t1 and t2 have equal strides (both dynamic or of same static value).
static FailureOr< StridedLayoutAttr > computeCollapsedLayoutMap(MemRefType srcType, ArrayRef< ReassociationIndices > reassociation, bool strict=false)
Compute the layout map after collapsing a given source MemRef type with the specified reassociation i...
static FailureOr< std::optional< SmallVector< Value > > > bubbleDownCastsPassthroughOpImpl(ConcreteOpTy op, OpBuilder &builder, OpOperand &src)
Implementation of bubbleDownCasts method for memref operations that return a single memref result.
static LogicalResult verifyAllocLikeOp(AllocLikeOp op)
static llvm::SmallBitVector getDroppedDims(ArrayRef< int64_t > reducedShape, ArrayRef< OpFoldResult > mixedSizes)
Compute the dropped dimensions of a rank-reducing tensor.extract_slice op or rank-extending tensor....
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
Definition AffineMap.h:46
@ Square
Square brackets surrounding zero or more operands.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseAffineMap(AffineMap &map)=0
Parse an affine map instance into 'map'.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseGreater()=0
Parse a '>' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalArrowTypeList(SmallVectorImpl< Type > &result)=0
Parse an optional arrow followed by a type list.
ParseResult parseKeywordType(const char *keyword, Type &result)
Parse a keyword followed by a type.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printAttributeWithoutType(Attribute attr)
Print the given attribute without its type.
Attributes are known-constant values of operations.
Definition Attributes.h:25
This class provides a shared interface for ranked and unranked memref types.
ArrayRef< int64_t > getShape() const
Returns the shape of this memref type.
FailureOr< PtrLikeTypeInterface > clonePtrWith(Attribute memorySpace, std::optional< Type > elementType) const
Clone this type with the given memory space and element type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Block represents an ordered list of Operations.
Definition Block.h:33
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:250
This class is a general helper class for creating context-global objects like types,...
Definition Builders.h:51
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
IndexType getIndexType()
Definition Builders.cpp:51
IRValueT get() const
Return the current value being used by this operand.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This is a builder type that keeps local references to arguments.
Builder & setShape(ArrayRef< int64_t > newShape)
Builder & setLayout(MemRefLayoutAttrInterface newLayout)
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
ParseResult parseTrailingOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None)
Parse zero or more trailing SSA comma-separated trailing operand references with a specified surround...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
void printOperands(const ContainerType &container)
Print a comma separated list of operands.
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition Builders.cpp:562
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition Builders.h:398
void createOrFold(SmallVectorImpl< Value > &results, Location location, Args &&...args)
Create an operation of specific op type at the current insertion point, and immediately try to fold i...
Definition Builders.h:526
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition Builders.h:412
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
A trait of region holding operations that define a new scope for automatic allocations,...
This trait indicates that the memory effects of an operation includes the effects of operations neste...
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
void replaceUsesOfWith(Value from, Value to)
Replace any uses of 'from' with 'to' within this operation.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
MutableArrayRef< OpOperand > getOpOperands()
Definition Operation.h:383
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
BlockArgument addArgument(Type type, Location loc)
Add one value to the argument list.
Definition Region.h:98
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static StridedMetadataRange getRanked(SmallVectorImpl< ConstantIntRanges > &&offsets, SmallVectorImpl< ConstantIntRanges > &&sizes, SmallVectorImpl< ConstantIntRanges > &&strides)
Returns a ranked strided metadata range.
ArrayRef< ConstantIntRanges > getStrides() const
Get the strides ranges.
bool isUninitialized() const
Returns whether the metadata is uninitialized.
ArrayRef< ConstantIntRanges > getOffsets() const
Get the offsets range.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
MLIRContext * getContext() const
Return the MLIRContext in which this type was uniqued.
Definition Types.cpp:35
bool isIndex() const
Definition Types.cpp:54
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
Speculatability
This enum is returned from the getSpeculatability method in the ConditionallySpeculatable op interfac...
constexpr auto Speculatable
constexpr auto NotSpeculatable
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition Matchers.h:344
FailureOr< std::optional< SmallVector< Value > > > bubbleDownInPlaceMemorySpaceCastImpl(OpOperand &operand, ValueRange results)
Tries to bubble-down inplace a MemorySpaceCastOpInterface operation referenced by operand.
ConstantIntRanges inferAdd(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferMul(ArrayRef< ConstantIntRanges > argRanges, OverflowFlags ovfFlags=OverflowFlags::None)
ConstantIntRanges inferShapedDimOpInterface(ShapedDimOpInterface op, const IntegerValueRange &maybeDim)
Returns the integer range for the result of a ShapedDimOpInterface given the optional inferred ranges...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition MemRefOps.cpp:60
OpFoldResult getMixedSize(OpBuilder &builder, Location loc, Value value, int64_t dim)
Return the dimension of the given memref value.
Definition MemRefOps.cpp:68
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:45
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given memref value.
Definition MemRefOps.cpp:77
Value createCanonicalRankReducingSubViewOp(OpBuilder &b, Location loc, Value memref, ArrayRef< int64_t > targetShape)
Create a rank-reducing SubViewOp @[0 .
Operation::operand_range getIndices(Operation *op)
Get the indices that the given load/store operation is operating on.
Definition Utils.cpp:18
DynamicAPInt getIndex(const ConeV &cone)
Get the index of a cone, i.e., the volume of the parallelepiped spanned by its generators,...
Definition Barvinok.cpp:63
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Value constantIndex(OpBuilder &builder, Location loc, int64_t i)
Generates a constant of index type.
MemRefType getMemRefType(T &&t)
Convenience method to abbreviate casting getType().
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
SmallVector< OpFoldResult > getMixedValues(ArrayRef< int64_t > staticValues, ValueRange dynamicValues, MLIRContext *context)
Return a vector of OpFoldResults with the same size a staticValues, but all elements for which Shaped...
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
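Sketch of the matcher in use (operand is an assumed Value):
llvm::APInt cst;
if (matchPattern(operand, m_ConstantInt(&cst))) {
  // `operand` is produced by a constant; `cst` holds its integer value.
}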
SliceVerificationResult
Enum that captures information related to verifier error conditions on slice insert/extract type of o...
constexpr StringRef getReassociationAttrName()
Attribute name for the ArrayAttr which encodes reassociation indices.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
std::optional< int64_t > getConstantIntValue(OpFoldResult ofr)
If ofr is a constant integer or an IntegerAttr, return the integer.
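Sketch (ofr is an assumed OpFoldResult):
if (std::optional<int64_t> c = getConstantIntValue(ofr)) {
  // `ofr` is statically known and *c is its integer value.
}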
raw_ostream & operator<<(raw_ostream &os, const AliasResult &result)
llvm::function_ref< void(Value, const IntegerValueRange &)> SetIntLatticeFn
Similar to SetIntRangeFn, but operating on IntegerValueRange lattice values.
static OpFoldResult foldReshapeOp(ReshapeOpTy reshapeOp, ArrayRef< Attribute > operands)
SliceBoundsVerificationResult verifyInBoundsSlice(ArrayRef< int64_t > shape, ArrayRef< int64_t > staticOffsets, ArrayRef< int64_t > staticSizes, ArrayRef< int64_t > staticStrides, bool generateErrorMessage=false)
Verify that the offsets/sizes/strides-style access into the given shape is in-bounds.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
SmallVector< Range, 8 > getOrCreateRanges(OffsetSizeAndStrideOpInterface op, OpBuilder &b, Location loc)
Return the list of Range (i.e. offset, size, and stride) entries for the op.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< AffineMap, 4 > getSymbolLessAffineMaps(ArrayRef< ReassociationExprs > reassociation)
Constructs affine maps out of Array<Array<AffineExpr>>.
bool isMemoryEffectFree(Operation *op)
Returns true if the given operation is free of memory effects.
bool hasValidSizesOffsets(SmallVector< int64_t > sizesOrOffsets)
Helper function to check whether the passed in sizes or offsets are valid.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
SmallVector< IntegerValueRange > getIntValueRanges(ArrayRef< OpFoldResult > values, GetIntRangeFn getIntRange, int32_t indexBitwidth)
Helper function to collect the integer range values of an array of op fold results.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
bool hasValidStrides(SmallVector< int64_t > strides)
Helper function to check whether the passed in strides are valid.
void dispatchIndexOpFoldResults(ArrayRef< OpFoldResult > ofrs, SmallVectorImpl< Value > &dynamicVec, SmallVectorImpl< int64_t > &staticVec)
Helper function to dispatch multiple OpFoldResults according to the behavior of dispatchIndexOpFoldRe...
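Sketch of splitting mixed offsets into the static/dynamic form used by OffsetSizeAndStrideOpInterface ops (mixedOffsets is an assumed ArrayRef<OpFoldResult>):
SmallVector<Value> dynamicOffsets;
SmallVector<int64_t> staticOffsets;
// Attributes land in `staticOffsets`; Values go to `dynamicOffsets`, with
// ShapedType::kDynamic recorded as the corresponding static placeholder.
dispatchIndexOpFoldResults(mixedOffsets, dynamicOffsets, staticOffsets);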
SmallVector< SmallVector< AffineExpr, 2 >, 2 > convertReassociationIndicesToExprs(MLIRContext *context, ArrayRef< ReassociationIndices > reassociationIndices)
Convert reassociation indices to affine expressions.
std::optional< SmallVector< OpFoldResult > > inferExpandShapeOutputShape(OpBuilder &b, Location loc, ShapedType expandedType, ArrayRef< ReassociationIndices > reassociation, ArrayRef< OpFoldResult > inputShape)
Infer the output shape for a {memref|tensor}.expand_shape when it is possible to do so.
Definition Utils.cpp:23
SmallVector< T > applyPermutationMap(AffineMap map, llvm::ArrayRef< T > source)
Apply a permutation from map to source and return the result.
Definition AffineMap.h:675
OpFoldResult getAsOpFoldResult(Value val)
Given a value, try to extract a constant Attribute.
function_ref< void(Value, const StridedMetadataRange &)> SetStridedMetadataRangeFn
Callback function type for setting the strided metadata of a value.
std::optional< llvm::SmallDenseSet< unsigned > > computeRankReductionMask(ArrayRef< int64_t > originalShape, ArrayRef< int64_t > reducedShape, bool matchDynamic=false)
Given an originalShape and a reducedShape assumed to be a subset of originalShape with some 1 entries...
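For example (a sketch; the shapes are illustrative):
// Dropping from 1x4x1x8 to 4x8 removes the unit dims at positions 0 and 2.
std::optional<llvm::SmallDenseSet<unsigned>> mask =
    computeRankReductionMask({1, 4, 1, 8}, {4, 8});
// On success, `mask` contains {0, 2}.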
SmallVector< int64_t, 2 > ReassociationIndices
Definition Utils.h:27
SliceVerificationResult isRankReducedType(ShapedType originalType, ShapedType candidateReducedType)
Check if originalType can be rank reduced to candidateReducedType type by dropping some dimensions wi...
ArrayAttr getReassociationIndicesAttribute(Builder &b, ArrayRef< ReassociationIndices > reassociation)
Wraps a list of reassociations in an ArrayAttr.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
bool isOneInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 1.
std::pair< SmallVector< int64_t >, SmallVector< Value > > decomposeMixedValues(ArrayRef< OpFoldResult > mixedValues)
Decompose a vector of mixed static or dynamic values into the corresponding pair of arrays.
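Sketch of the decomposition (mixedSizes is an assumed ArrayRef<OpFoldResult>):
// The static vector records kDynamic wherever the mixed entry was a Value;
// the dynamic vector holds those Values in order.
auto [staticSizes, dynamicSizes] = decomposeMixedValues(mixedSizes);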
function_ref< IntegerValueRange(Value)> GetIntRangeFn
Helper callback type to get the integer range of a value.
Move allocations into an allocation scope, if it is legal to move them (e.g.
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
Inline an AllocaScopeOp if either the direct parent is an allocation scope or it contains no allocati...
LogicalResult matchAndRewrite(AllocaScopeOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CollapseShapeOp op, PatternRewriter &rewriter) const override
A canonicalizer wrapper to replace SubViewOps.
void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp)
Return the canonical type of the result of a subview.
MemRefType operator()(SubViewOp op, ArrayRef< OpFoldResult > mixedOffsets, ArrayRef< OpFoldResult > mixedSizes, ArrayRef< OpFoldResult > mixedStrides)
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Represents a range (offset, size, and stride) where each element of the triple may be dynamic or stat...
OpFoldResult stride
OpFoldResult size
OpFoldResult offset
static SaturatedInteger wrap(int64_t v)
bool isValid
If set to "true", the slice bounds verification was successful.
std::string errorMessage
An error message that can be printed during op verification.
Eliminates the variable at the specified position using Fourier-Motzkin variable elimination.