MLIR 22.0.0git
BufferizationOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/IR/Matchers.h"
16#include <optional>
17
18using namespace mlir;
19using namespace mlir::bufferization;
20
21//===----------------------------------------------------------------------===//
22// Helper functions
23//===----------------------------------------------------------------------===//
24
26 OpBuilder &b, Value value, MemRefType destType,
28 auto srcType = llvm::cast<MemRefType>(value.getType());
29
30 // Element type and rank must match.
31 if (srcType.getElementType() != destType.getElementType())
32 return failure();
33 if (srcType.getRank() != destType.getRank())
34 return failure();
35
36 // In case the affine maps are different, we may need to use a copy if we go
37 // from dynamic to static offset or stride (the canonicalization cannot know
38 // at this point that it is really cast compatible).
39 auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
40 int64_t sourceOffset, targetOffset;
41 SmallVector<int64_t, 4> sourceStrides, targetStrides;
42 if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
43 failed(target.getStridesAndOffset(targetStrides, targetOffset)))
44 return false;
45 auto dynamicToStatic = [](int64_t a, int64_t b) {
46 return ShapedType::isDynamic(a) && ShapedType::isStatic(b);
47 };
48 if (dynamicToStatic(sourceOffset, targetOffset))
49 return false;
50 for (auto it : zip(sourceStrides, targetStrides))
51 if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
52 return false;
53 return true;
54 };
55
56 // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
57 // ensure that we only generate casts that always succeed at runtime, we check
58 // a fix extra conditions in `isGuaranteedCastCompatible`.
59 if (memref::CastOp::areCastCompatible(srcType, destType) &&
60 isGuaranteedCastCompatible(srcType, destType)) {
61 Value casted = memref::CastOp::create(b, value.getLoc(), destType, value);
62 return casted;
63 }
64
65 auto loc = value.getLoc();
66 SmallVector<Value, 4> dynamicOperands;
67 for (int i = 0; i < destType.getRank(); ++i) {
68 if (destType.getShape()[i] != ShapedType::kDynamic)
69 continue;
70 Value size = memref::DimOp::create(b, loc, value, i);
71 dynamicOperands.push_back(size);
72 }
73
74 FailureOr<Value> copy =
75 options.createAlloc(b, loc, destType, dynamicOperands);
76 if (failed(copy))
77 return failure();
78 if (failed(options.createMemCpy(b, loc, value, *copy)))
79 return failure();
80 return copy;
81}
82
83/// Try to fold to_buffer(to_tensor(x)). If x's type and the result type of the
84/// to_buffer op are different, a memref.cast is needed.
86 RewriterBase &rewriter, ToBufferOp toBuffer,
88 auto bufferToTensor = toBuffer.getTensor().getDefiningOp<ToTensorOp>();
89 if (!bufferToTensor)
90 return failure();
91
92 Type srcType = bufferToTensor.getBuffer().getType();
93 Type destType = toBuffer.getType();
94
95 // Directly rewrite if the type did not change.
96 if (srcType == destType) {
97 rewriter.replaceOp(toBuffer, bufferToTensor.getBuffer());
98 return success();
99 }
100
101 auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
102 auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
103 auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);
104
105 // Ranked memref -> Ranked memref cast.
106 if (rankedSrcType && rankedDestType) {
107 FailureOr<Value> replacement = castOrReallocMemRefValue(
108 rewriter, bufferToTensor.getBuffer(), rankedDestType, options);
109 if (failed(replacement))
110 return failure();
111
112 rewriter.replaceOp(toBuffer, *replacement);
113 return success();
114 }
115
116 // Unranked memref -> Ranked memref cast: May require a copy.
117 // TODO: Not implemented at the moment.
118 if (unrankedSrcType && rankedDestType)
119 return failure();
120
121 // Unranked memref -> unranked memref cast
122 // Ranked memref -> unranked memref cast: No copy needed.
123 assert(memref::CastOp::areCastCompatible(srcType, destType) &&
124 "expected that types are cast compatible");
125 rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, destType,
126 bufferToTensor.getBuffer());
127 return success();
128}
129
131 OpBuilder &b, Location loc, Value shapedValue,
132 SmallVector<Value> &dynamicDims) {
133 auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
134 for (int64_t i = 0; i < shapedType.getRank(); ++i) {
135 if (shapedType.isDynamicDim(i)) {
136 if (llvm::isa<MemRefType>(shapedType)) {
137 dynamicDims.push_back(memref::DimOp::create(b, loc, shapedValue, i));
138 } else {
139 assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
140 dynamicDims.push_back(tensor::DimOp::create(b, loc, shapedValue, i));
141 }
142 }
143 }
144}
145
146//===----------------------------------------------------------------------===//
147// AllocTensorOp
148//===----------------------------------------------------------------------===//
149
150LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
152 BufferizationState &state) {
153 OpBuilder::InsertionGuard g(rewriter);
154 Location loc = getLoc();
155
156 // Nothing to do for dead AllocTensorOps.
157 if (getOperation()->getUses().empty()) {
158 rewriter.eraseOp(getOperation());
159 return success();
160 }
161
162 // Get "copy" buffer.
163 Value copyBuffer;
164 if (getCopy()) {
165 FailureOr<Value> maybeCopyBuffer =
166 getBuffer(rewriter, getCopy(), options, state);
167 if (failed(maybeCopyBuffer))
168 return failure();
169 copyBuffer = *maybeCopyBuffer;
170 }
171
172 // Create memory allocation.
173 auto allocType = bufferization::getBufferType(getResult(), options, state);
174 if (failed(allocType))
175 return failure();
176 SmallVector<Value> dynamicDims = getDynamicSizes();
177 if (getCopy()) {
178 assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
179 populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
180 }
181 FailureOr<Value> alloc = options.createAlloc(
182 rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
183 if (failed(alloc))
184 return failure();
185
186 // Create memory copy (if any).
187 if (getCopy()) {
188 if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
189 return failure();
190 }
191
192 // Replace op.
193 replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
194
195 return success();
196}
197
198bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
199 const AnalysisState &state) {
200 // AllocTensorOps do not write unless they have a `copy` value.
201 return static_cast<bool>(getCopy());
202}
203
204bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
205 const AnalysisState &state) {
206 assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
207 "expected copy operand");
208 return true;
209}
210
211bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
212 const AnalysisState &state) {
213 assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
214 "expected copy operand");
215 return false;
216}
217
218AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
219 const AnalysisState &state) {
220 // This is a new allocation. It does not alias with any other buffer.
221 return {};
222}
223
224FailureOr<BufferLikeType>
225AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
226 const BufferizationState &state,
227 SmallVector<Value> &invocationStack) {
228 assert(value == getResult() && "invalid value");
229
230 // Compute memory space of this allocation.
231 Attribute memorySpace;
232 if (getMemorySpace().has_value()) {
233 memorySpace = *getMemorySpace();
234 } else if (getCopy()) {
235 auto copyBufferType =
236 bufferization::detail::asMemRefType(bufferization::getBufferType(
237 getCopy(), options, state, invocationStack));
238 if (failed(copyBufferType))
239 return failure();
240 memorySpace = copyBufferType->getMemorySpace();
241 } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
242 memorySpace = *ms;
243 } else {
244 return getOperation()->emitError("could not infer memory space");
245 }
246
247 return cast<BufferLikeType>(
248 getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
249}
250
251LogicalResult AllocTensorOp::verify() {
252 if (getCopy() && !getDynamicSizes().empty())
253 return emitError("dynamic sizes not needed when copying a tensor");
254 if (!getCopy() && getType().getNumDynamicDims() != getDynamicSizes().size())
255 return emitError("expected ")
256 << getType().getNumDynamicDims() << " dynamic sizes";
257 if (getCopy() && getCopy().getType() != getType())
258 return emitError("expected that `copy` and return type match");
259 return success();
260}
261
262void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
263 RankedTensorType type, ValueRange dynamicSizes) {
264 build(builder, result, type, dynamicSizes, /*copy=*/Value(),
265 /*size_hint=*/Value(),
266 /*memory_space=*/IntegerAttr());
267}
268
269void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
270 RankedTensorType type, ValueRange dynamicSizes,
271 Value copy) {
272 build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
273 /*memory_space=*/IntegerAttr());
274}
275
276void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
277 TensorType type, ValueRange dynamicSizes, Value copy,
278 IntegerAttr memorySpace) {
279 build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
280 memorySpace);
281}
282
283namespace {
284/// Change the type of the result of a `bufferization.alloc_tensor` by making
285/// the result type statically sized along dimension that in the original
286/// operation where defined as dynamic, but the size was defined using a
287/// `constant` op. For example:
288///
289/// %c5 = arith.constant 5: index
290/// %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
291///
292/// to
293///
294/// %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
295struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
296 using OpRewritePattern<AllocTensorOp>::OpRewritePattern;
297
298 LogicalResult matchAndRewrite(AllocTensorOp op,
299 PatternRewriter &rewriter) const override {
300 if (op.getCopy())
301 return failure();
302 SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
303 SmallVector<Value> newDynamicSizes;
304 unsigned int dynValCounter = 0;
305 for (int64_t i = 0; i < op.getType().getRank(); ++i) {
306 if (!op.isDynamicDim(i))
307 continue;
308 Value value = op.getDynamicSizes()[dynValCounter++];
309 APInt intVal;
310 if (matchPattern(value, m_ConstantInt(&intVal))) {
311 int64_t dim = intVal.getSExtValue();
312 if (dim >= 0)
313 newShape[i] = intVal.getSExtValue();
314 else
315 newDynamicSizes.push_back(value);
316 } else {
317 newDynamicSizes.push_back(value);
318 }
319 }
320 RankedTensorType newType = RankedTensorType::get(
321 newShape, op.getType().getElementType(), op.getType().getEncoding());
322 if (newType == op.getType())
323 return failure();
324 auto newOp = AllocTensorOp::create(rewriter, op.getLoc(), newType,
325 newDynamicSizes, /*copy=*/Value());
326 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
327 return success();
328 }
329};
330
331struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
332 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
333
334 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
335 PatternRewriter &rewriter) const override {
336 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
337 auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
338 if (!allocTensorOp || !maybeConstantIndex)
339 return failure();
340 if (*maybeConstantIndex < 0 ||
341 *maybeConstantIndex >= allocTensorOp.getType().getRank())
342 return failure();
343 if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
344 return failure();
345 rewriter.replaceOp(
346 dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
347 return success();
348 }
349};
350} // namespace
351
352void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
353 MLIRContext *ctx) {
354 results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
355}
356
357LogicalResult AllocTensorOp::reifyResultShapes(
358 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
359 auto shapes = llvm::to_vector<4>(
360 llvm::map_range(llvm::seq<int64_t>(0, getType().getRank()),
361 [&](int64_t dim) -> OpFoldResult {
362 if (isDynamicDim(dim))
363 return getDynamicSize(builder, dim);
364 return builder.getIndexAttr(getStaticSize(dim));
365 }));
366 reifiedReturnShapes.emplace_back(std::move(shapes));
367 return success();
368}
369
370ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
372 if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
373 parser.parseRParen())
374 return failure();
375 ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
377 if (copyKeyword.succeeded())
378 if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
379 parser.parseRParen())
380 return failure();
381 ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
382 OpAsmParser::UnresolvedOperand sizeHintOperand;
383 if (sizeHintKeyword.succeeded())
384 if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
385 return failure();
386 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
387 return failure();
388
389 TensorType type;
390 if (parser.parseCustomTypeWithFallback(type))
391 return failure();
392 result.addTypes(type);
393
394 Type indexType = parser.getBuilder().getIndexType();
395 if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
396 return failure();
397 if (copyKeyword.succeeded())
398 if (parser.resolveOperand(copyOperand, type, result.operands))
399 return failure();
400 if (sizeHintKeyword.succeeded())
401 if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
402 return failure();
403 result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
405 {static_cast<int32_t>(dynamicSizesOperands.size()),
406 static_cast<int32_t>(copyKeyword.succeeded()),
407 static_cast<int32_t>(sizeHintKeyword.succeeded())}));
408 return success();
409}
410
411void AllocTensorOp::print(OpAsmPrinter &p) {
412 p << "(" << getDynamicSizes() << ")";
413 if (getCopy())
414 p << " copy(" << getCopy() << ")";
415 if (getSizeHint())
416 p << " size_hint=" << getSizeHint();
417 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
418 AllocTensorOp::getOperandSegmentSizeAttr()});
419 p << " : ";
420 auto type = getResult().getType();
421 if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
422 p.printStrippedAttrOrType(validType);
423 else
424 p << type;
425}
426
427Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
428 assert(isDynamicDim(idx) && "expected dynamic dim");
429 if (getCopy())
430 return tensor::DimOp::create(b, getLoc(), getCopy(), idx);
431 return getOperand(getIndexOfDynamicSize(idx));
432}
433
434//===----------------------------------------------------------------------===//
435// CloneOp
436//===----------------------------------------------------------------------===//
437
438OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
439 return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
440}
441
442namespace {
443
444/// Merge the clone and its source (by converting the clone to a cast) when
445/// possible.
446struct SimplifyClones : public OpRewritePattern<CloneOp> {
447 using OpRewritePattern<CloneOp>::OpRewritePattern;
448
449 LogicalResult matchAndRewrite(CloneOp cloneOp,
450 PatternRewriter &rewriter) const override {
451 if (cloneOp.use_empty()) {
452 rewriter.eraseOp(cloneOp);
453 return success();
454 }
455
456 Value source = cloneOp.getInput();
457 if (source.getType() != cloneOp.getType() &&
458 !memref::CastOp::areCastCompatible({source.getType()},
459 {cloneOp.getType()}))
460 return failure();
461
462 // Aims to find the dealloc op for the canonical source
463 // which otherwise could prevent removal of unnecessary allocs.
464 Value canonicalSource = source;
465 while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
466 canonicalSource.getDefiningOp())) {
467 if (canonicalSource != iface.getViewDest()) {
468 break;
469 }
470 canonicalSource = iface.getViewSource();
471 }
472
473 std::optional<Operation *> maybeCloneDeallocOp =
474 memref::findDealloc(cloneOp.getOutput());
475 // Skip if either of them has > 1 deallocate operations.
476 if (!maybeCloneDeallocOp.has_value())
477 return failure();
478 std::optional<Operation *> maybeSourceDeallocOp =
479 memref::findDealloc(canonicalSource);
480 if (!maybeSourceDeallocOp.has_value())
481 return failure();
482 Operation *cloneDeallocOp = *maybeCloneDeallocOp;
483 Operation *sourceDeallocOp = *maybeSourceDeallocOp;
484
485 // If both are deallocated in the same block, their in-block lifetimes
486 // might not fully overlap, so we cannot decide which one to drop.
487 if (cloneDeallocOp && sourceDeallocOp &&
488 cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
489 return failure();
490
491 Block *currentBlock = cloneOp->getBlock();
492 Operation *redundantDealloc = nullptr;
493 if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
494 redundantDealloc = cloneDeallocOp;
495 } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
496 redundantDealloc = sourceDeallocOp;
497 }
498
499 if (!redundantDealloc)
500 return failure();
501
502 // Safety check that there are no other deallocations inbetween
503 // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
504 // of source before the uses of the clone. With alias information, we could
505 // restrict this to only fail of the dealloc's operand is an alias
506 // of the source.
507 for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
508 pos = pos->getNextNode()) {
509 // Bail if we run out of operations while looking for a deallocation op.
510 if (!pos)
511 return failure();
512 auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
513 if (!effectInterface)
514 continue;
515 if (effectInterface.hasEffect<MemoryEffects::Free>())
516 return failure();
517 }
518
519 if (source.getType() != cloneOp.getType())
520 source = memref::CastOp::create(rewriter, cloneOp.getLoc(),
521 cloneOp.getType(), source);
522 rewriter.replaceOp(cloneOp, source);
523 rewriter.eraseOp(redundantDealloc);
524 return success();
525 }
526};
527
528} // namespace
529
530void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
531 MLIRContext *context) {
532 results.add<SimplifyClones>(context);
533}
534
535//===----------------------------------------------------------------------===//
536// DeallocTensorOp
537//===----------------------------------------------------------------------===//
538
539LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
541 BufferizationState &state) {
542 FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
543 if (failed(buffer))
544 return failure();
545 memref::DeallocOp::create(rewriter, getLoc(), *buffer);
546 rewriter.eraseOp(getOperation());
547 return success();
548}
549
550//===----------------------------------------------------------------------===//
551// MaterializeInDestinationOp
552//===----------------------------------------------------------------------===//
553
554bool MaterializeInDestinationOp::bufferizesToMemoryRead(
555 OpOperand &opOperand, const AnalysisState &state) {
556 return opOperand == getSourceMutable();
557}
558
559bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
560 OpOperand &opOperand, const AnalysisState &state) {
561 if (opOperand == getDestMutable()) {
562 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
563 return true;
564 }
565 return false;
566}
567
568bool MaterializeInDestinationOp::mustBufferizeInPlace(
569 OpOperand &opOperand, const AnalysisState &state) {
570 // The source is only read and not written, so it always bufferizes in-place
571 // by default. The destination is written and is forced to bufferize in-place
572 // (if it is a tensor).
573 return true;
574}
575
576AliasingValueList
577MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
578 const AnalysisState &state) {
579 if (opOperand == getDestMutable()) {
580 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
581 return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
582 }
583 return {};
584}
585
586LogicalResult
587MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
589 BufferizationState &state) {
590 bool tensorDest = isa<TensorType>(getDest().getType());
591 Value buffer;
592 if (tensorDest) {
593 FailureOr<Value> maybeBuffer =
594 getBuffer(rewriter, getDest(), options, state);
595 if (failed(maybeBuffer))
596 return failure();
597 buffer = *maybeBuffer;
598 } else {
599 assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
600 buffer = getDest();
601 }
602 auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
603 if (failed(srcBuffer))
604 return failure();
605 if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
606 return failure();
607 replaceOpWithBufferizedValues(rewriter, getOperation(),
608 tensorDest ? ValueRange(buffer) : ValueRange());
609 return success();
610}
611
612bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
613 const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
614 // As elements are copied from the "source" buffer to the "dest" buffer,
615 // already copied elements are not read a second time.
616 return true;
617}
618
619LogicalResult MaterializeInDestinationOp::reifyResultShapes(
620 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
621 if (getOperation()->getNumResults() == 1) {
622 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
623 reifiedReturnShapes.resize(1,
625 reifiedReturnShapes[0] =
626 tensor::getMixedSizes(builder, getLoc(), getDest());
627 }
628 return success();
629}
630
631Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
632 Location loc) {
633 if (isa<TensorType>(getDest().getType())) {
634 // The subset is the entire destination tensor.
635 return getDest();
636 }
637
638 // The "restrict" attribute is transferred from this op to the newly created
639 // to_tensor op. If this op does not the "restrict" attribute, the subset
640 // extraction cannot be built because there is no guarantee that there is no
641 // pre-existing "restrict" to_tensor op with the same/an aliasing destination.
642 if (!getRestrict())
643 return {};
644
645 // Build a bufferization.to_tensor op.
646 assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
647 assert(getRestrict() &&
648 "expected that ops with memrefs dest have 'restrict'");
649 setRestrict(false);
650 return ToTensorOp::create(
651 builder, loc, memref::getTensorTypeFromMemRefType(getDest().getType()),
652 getDest(),
653 /*restrict=*/true, getWritable());
654}
655
656bool MaterializeInDestinationOp::isEquivalentSubset(
657 Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
658 return equivalenceFn(getDest(), candidate);
659}
660
662MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
663 return {getDest()};
664}
665
666OpOperand &MaterializeInDestinationOp::getSourceOperand() {
667 return getOperation()->getOpOperand(0) /*source*/;
668}
669
670bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
671 SubsetOpInterface subsetOp,
672 function_ref<bool(Value, Value)> equivalenceFn) {
673 return false;
674}
675
676bool MaterializeInDestinationOp::operatesOnDisjointSubset(
677 SubsetOpInterface subsetOp,
678 function_ref<bool(Value, Value)> equivalenceFn) {
679 return false;
680}
681
682LogicalResult MaterializeInDestinationOp::verify() {
683 if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
684 return emitOpError("'dest' must be a tensor or a memref");
685 if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
686 if (getOperation()->getNumResults() != 1)
687 return emitOpError("tensor 'dest' implies exactly one tensor result");
688 if (destType != getResult().getType())
689 return emitOpError("result and 'dest' types must match");
690 }
691 if (isa<BaseMemRefType>(getDest().getType()) &&
692 getOperation()->getNumResults() != 0)
693 return emitOpError("memref 'dest' implies zero results");
694 if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
695 return emitOpError("'restrict' is valid only for memref destinations");
696 if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
697 return emitOpError("'writable' must be specified if and only if the "
698 "destination is of memref type");
699 TensorType srcType = getSource().getType();
700 ShapedType destType = cast<ShapedType>(getDest().getType());
701 if (srcType.hasRank() != destType.hasRank())
702 return emitOpError("source/destination shapes are incompatible");
703 if (srcType.hasRank()) {
704 if (srcType.getRank() != destType.getRank())
705 return emitOpError("rank mismatch between source and destination shape");
706 for (auto [src, dest] :
707 llvm::zip(srcType.getShape(), destType.getShape())) {
708 if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
709 // Cannot verify dynamic dimension size. Assume that that they match at
710 // runtime.
711 continue;
712 }
713 if (src != dest)
714 return emitOpError("source/destination shapes are incompatible");
715 }
716 }
717 return success();
718}
719
720void MaterializeInDestinationOp::build(OpBuilder &builder,
721 OperationState &state, Value source,
722 Value dest) {
723 auto destTensorType = dyn_cast<TensorType>(dest.getType());
724 build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
725 source, dest);
726}
727
728bool MaterializeInDestinationOp::isWritable(Value value,
729 const AnalysisState &state) {
730 return isa<TensorType>(getDest().getType()) ? true : getWritable();
731}
732
733MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
734 return getDestMutable();
735}
736
737void MaterializeInDestinationOp::getEffects(
739 &effects) {
740 if (isa<BaseMemRefType>(getDest().getType()))
741 effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
743}
744
745//===----------------------------------------------------------------------===//
746// ToTensorOp
747//===----------------------------------------------------------------------===//
748
749bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
750 return getWritable();
751}
752
753OpFoldResult ToTensorOp::fold(FoldAdaptor) {
754 if (auto toBuffer = getBuffer().getDefiningOp<ToBufferOp>())
755 // Approximate alias analysis by conservatively folding only when no there
756 // is no interleaved operation.
757 if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
758 toBuffer->getNextNode() == this->getOperation())
759 return toBuffer.getTensor();
760 return {};
761}
762
763namespace {
764struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
765 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
766
767 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
768 PatternRewriter &rewriter) const override {
769 auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
770 if (!memrefToTensorOp)
771 return failure();
772
773 rewriter.replaceOpWithNewOp<memref::DimOp>(
774 dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex());
775 return success();
776 }
777};
778} // namespace
779
780void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
781 MLIRContext *context) {
782 results.add<DimOfToTensorFolder>(context);
783}
784
785//===----------------------------------------------------------------------===//
786// ToBufferOp
787//===----------------------------------------------------------------------===//
788
789OpFoldResult ToBufferOp::fold(FoldAdaptor) {
790 if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
791 if (memrefToTensor.getBuffer().getType() == getType())
792 return memrefToTensor.getBuffer();
793 return {};
794}
795
796namespace {
797
798/// Replace tensor.cast + to_buffer by to_buffer + memref.cast.
799struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
800 using OpRewritePattern<ToBufferOp>::OpRewritePattern;
801
802 LogicalResult matchAndRewrite(ToBufferOp toBuffer,
803 PatternRewriter &rewriter) const final {
804 auto tensorCastOperand =
805 toBuffer.getOperand().getDefiningOp<tensor::CastOp>();
806 if (!tensorCastOperand)
807 return failure();
808 auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
809 tensorCastOperand.getOperand().getType());
810 if (!srcTensorType)
811 return failure();
812 auto currentOutputMemRefType =
813 dyn_cast<BaseMemRefType>(toBuffer.getResult().getType());
814 if (!currentOutputMemRefType)
815 return failure();
816
817 auto memrefType = currentOutputMemRefType.cloneWith(
818 srcTensorType.getShape(), srcTensorType.getElementType());
819 Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
820 tensorCastOperand.getOperand(),
821 toBuffer.getReadOnly());
822 rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
823 memref);
824 return success();
825 }
826};
827
828/// Canonicalize bufferization.to_tensor + bufferization.to_buffer. Insert a
829/// cast if necessary.
830struct ToBufferToTensorFolding : public OpRewritePattern<ToBufferOp> {
831 using OpRewritePattern<ToBufferOp>::OpRewritePattern;
832
833 LogicalResult matchAndRewrite(ToBufferOp toBuffer,
834 PatternRewriter &rewriter) const final {
835 BufferizationOptions options;
836 options.bufferAlignment = 0;
837 return foldToBufferToTensorPair(rewriter, toBuffer, options);
838 }
839};
840
841/// Fold a load on a to_buffer operation into an tensor.extract on the
842/// corresponding tensor.
843struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
844 using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
845
846 LogicalResult matchAndRewrite(memref::LoadOp load,
847 PatternRewriter &rewriter) const override {
848 auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
849 if (!toBuffer)
850 return failure();
851
852 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
853 load.getIndices());
854 return success();
855 }
856};
857
858/// Fold dim of a to_buffer into the dim of the tensor.
859struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
860 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
861
862 LogicalResult matchAndRewrite(memref::DimOp dimOp,
863 PatternRewriter &rewriter) const override {
864 auto castOp = dimOp.getSource().getDefiningOp<ToBufferOp>();
865 if (!castOp)
866 return failure();
867 Value newSource = castOp.getOperand();
868 rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
869 dimOp.getIndex());
870 return success();
871 }
872};
873
874} // namespace
875
876void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
877 MLIRContext *context) {
878 results.add<DimOfCastOp, LoadOfToBuffer, ToBufferOfCast,
879 ToBufferToTensorFolding>(context);
880}
881
882LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
884 BufferizationState &state) {
885 // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
886 (void)foldToBufferToTensorPair(rewriter, *this, options);
887 // Note: The return value of `bufferize` indicates whether there was an error
888 // or not. (And not whether the pattern matched or not.)
889 return success();
890}
891
892std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
893 Value alloc) {
894 return memref::DeallocOp::create(builder, alloc.getLoc(), alloc)
895 .getOperation();
896}
897
898std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
899 return CloneOp::create(builder, alloc.getLoc(), alloc).getResult();
900}
901
902//===----------------------------------------------------------------------===//
903// DeallocOp
904//===----------------------------------------------------------------------===//
905
906LogicalResult DeallocOp::inferReturnTypes(
907 MLIRContext *context, std::optional<::mlir::Location> location,
908 ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
909 RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
910 DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
911 inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
912 IntegerType::get(context, 1));
913 return success();
914}
915
916LogicalResult DeallocOp::verify() {
917 if (getMemrefs().size() != getConditions().size())
918 return emitOpError(
919 "must have the same number of conditions as memrefs to deallocate");
920 if (getRetained().size() != getUpdatedConditions().size())
921 return emitOpError("must have the same number of updated conditions "
922 "(results) as retained operands");
923 return success();
924}
925
926static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
927 ValueRange memrefs,
928 ValueRange conditions,
929 PatternRewriter &rewriter) {
930 if (deallocOp.getMemrefs() == memrefs &&
931 deallocOp.getConditions() == conditions)
932 return failure();
933
934 rewriter.modifyOpInPlace(deallocOp, [&]() {
935 deallocOp.getMemrefsMutable().assign(memrefs);
936 deallocOp.getConditionsMutable().assign(conditions);
937 });
938 return success();
939}
940
941namespace {
942
943/// Remove duplicate values in the list of memrefs to be deallocated. We need to
944/// make sure the corresponding condition value is updated accordingly since
945/// their two conditions might not cover the same set of cases. In that case, we
946/// have to combine them (by computing the disjunction of them).
947/// Example:
948/// ```mlir
949/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
950/// ```
951/// is canonicalized to
952/// ```mlir
953/// %0 = arith.ori %arg1, %arg2 : i1
954/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
955/// ```
956struct DeallocRemoveDuplicateDeallocMemrefs
957 : public OpRewritePattern<DeallocOp> {
958 using OpRewritePattern<DeallocOp>::OpRewritePattern;
959
960 LogicalResult matchAndRewrite(DeallocOp deallocOp,
961 PatternRewriter &rewriter) const override {
962 // Unique memrefs to be deallocated.
963 DenseMap<Value, unsigned> memrefToCondition;
964 SmallVector<Value> newMemrefs, newConditions;
965 for (auto [i, memref, cond] :
966 llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
967 if (memrefToCondition.count(memref)) {
968 // If the dealloc conditions don't match, we need to make sure that the
969 // dealloc happens on the union of cases.
970 Value &newCond = newConditions[memrefToCondition[memref]];
971 if (newCond != cond)
972 newCond =
973 arith::OrIOp::create(rewriter, deallocOp.getLoc(), newCond, cond);
974 } else {
975 memrefToCondition.insert({memref, newConditions.size()});
976 newMemrefs.push_back(memref);
977 newConditions.push_back(cond);
978 }
979 }
980
981 // Return failure if we don't change anything such that we don't run into an
982 // infinite loop of pattern applications.
983 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
984 rewriter);
985 }
986};
987
988/// Remove duplicate values in the list of retained memrefs. We need to make
989/// sure the corresponding result condition value is replaced properly.
990/// Example:
991/// ```mlir
992/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
993/// ```
994/// is canonicalized to
995/// ```mlir
996/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
997/// ```
998struct DeallocRemoveDuplicateRetainedMemrefs
999 : public OpRewritePattern<DeallocOp> {
1000 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1001
1002 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1003 PatternRewriter &rewriter) const override {
1004 // Unique retained values
1006 SmallVector<Value> newRetained;
1007 SmallVector<unsigned> resultReplacementIdx;
1008 unsigned i = 0;
1009 for (auto retained : deallocOp.getRetained()) {
1010 if (seen.count(retained)) {
1011 resultReplacementIdx.push_back(seen[retained]);
1012 continue;
1013 }
1014
1015 seen[retained] = i;
1016 newRetained.push_back(retained);
1017 resultReplacementIdx.push_back(i++);
1018 }
1019
1020 // Return failure if we don't change anything such that we don't run into an
1021 // infinite loop of pattern applications.
1022 if (newRetained.size() == deallocOp.getRetained().size())
1023 return failure();
1024
1025 // We need to create a new op because the number of results is always the
1026 // same as the number of condition operands.
1027 auto newDeallocOp =
1028 DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
1029 deallocOp.getConditions(), newRetained);
1030 SmallVector<Value> replacements(
1031 llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
1032 return newDeallocOp.getUpdatedConditions()[idx];
1033 }));
1034 rewriter.replaceOp(deallocOp, replacements);
1035 return success();
1036 }
1037};
1038
1039/// Erase deallocation operations where the variadic list of memrefs to
1040/// deallocate is empty. Example:
1041/// ```mlir
1042/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
1043/// ```
1044struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
1045 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1046
1047 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1048 PatternRewriter &rewriter) const override {
1049 if (deallocOp.getMemrefs().empty()) {
1050 Value constFalse = arith::ConstantOp::create(rewriter, deallocOp.getLoc(),
1051 rewriter.getBoolAttr(false));
1052 rewriter.replaceOp(
1053 deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
1054 constFalse));
1055 return success();
1056 }
1057 return failure();
1058 }
1059};
1060
1061/// Removes memrefs from the deallocation list if their associated condition is
1062/// always 'false'.
1063///
1064/// Example:
1065/// ```
1066/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
1067/// if (%arg2, %false)
1068/// ```
1069/// becomes
1070/// ```
1071/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
1072/// ```
1073struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
1074 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1075
1076 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1077 PatternRewriter &rewriter) const override {
1078 SmallVector<Value> newMemrefs, newConditions;
1079 for (auto [memref, cond] :
1080 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1081 if (!matchPattern(cond, m_Zero())) {
1082 newMemrefs.push_back(memref);
1083 newConditions.push_back(cond);
1084 }
1085 }
1086
1087 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1088 rewriter);
1089 }
1090};
1091
1092/// The `memref.extract_strided_metadata` is often inserted to get the base
1093/// memref if the operand is not already guaranteed to be the result of a memref
1094/// allocation operation. This canonicalization pattern removes this extraction
1095/// operation if the operand is now produced by an allocation operation (e.g.,
1096/// due to other canonicalizations simplifying the IR).
1097///
1098/// Example:
1099/// ```mlir
1100/// %alloc = memref.alloc() : memref<2xi32>
1101/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
1102/// %alloc : memref<2xi32> -> memref<i32>, index, index, index
1103/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
1104/// ```
1105/// is canonicalized to
1106/// ```mlir
1107/// %alloc = memref.alloc() : memref<2xi32>
1108/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
1109/// ```
1110struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
1111 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1112
1113 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1114 PatternRewriter &rewriter) const override {
1115 SmallVector<Value> newMemrefs(
1116 llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
1117 auto extractStridedOp =
1118 memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
1119 if (!extractStridedOp)
1120 return memref;
1121 Value allocMemref = extractStridedOp.getOperand();
1122 auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
1123 if (!allocOp)
1124 return memref;
1125 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
1126 return allocMemref;
1127 return memref;
1128 }));
1129
1130 return updateDeallocIfChanged(deallocOp, newMemrefs,
1131 deallocOp.getConditions(), rewriter);
1132 }
1133};
1134
1135/// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
1136/// other user of the allocated value and the allocating operation can be safely
1137/// removed. If the same value is present multiple times, this pattern relies on
1138/// other canonicalization patterns to remove the duplicate first.
1139///
1140/// Example:
1141/// ```mlir
1142/// %alloc = memref.alloc() : memref<2xi32>
1143/// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true)
1144/// ```
1145/// is canonicalized to
1146/// ```mlir
1147/// bufferization.dealloc (%arg0 : ...) if (%true)
1148/// ```
1149struct RemoveAllocDeallocPairWhenNoOtherUsers
1150 : public OpRewritePattern<DeallocOp> {
1151 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1152
1153 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1154 PatternRewriter &rewriter) const override {
1155 SmallVector<Value> newMemrefs, newConditions;
1156 SmallVector<Operation *> toDelete;
1157 for (auto [memref, cond] :
1158 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1159 if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
1160 // Check that it is indeed an allocate effect, that the op has no other
1161 // side effects (which would not allow us to remove the op), and that
1162 // there are no other users.
1163 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
1165 memref.hasOneUse()) {
1166 toDelete.push_back(allocOp);
1167 continue;
1168 }
1169 }
1170
1171 newMemrefs.push_back(memref);
1172 newConditions.push_back(cond);
1173 }
1174
1175 if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1176 rewriter)))
1177 return failure();
1178
1179 for (Operation *op : toDelete)
1180 rewriter.eraseOp(op);
1181
1182 return success();
1183 }
1184};
1185
1186} // anonymous namespace
1187
1188void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
1189 MLIRContext *context) {
1191}
1192
1195 patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
1196 DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1197 EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1198 RemoveAllocDeallocPairWhenNoOtherUsers>(context);
1199}
1200
1201//===----------------------------------------------------------------------===//
1202// TableGen'd op method definitions
1203//===----------------------------------------------------------------------===//
1204
1205#define GET_OP_CLASSES
1206#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static SmallVector< Value > getDynamicSize(Value memref, func::FuncOp funcOp)
Return the dynamic shapes of the memref based on the defining op.
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
true
Given two iterators into the same block, return "true" if a is before `b.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base class for generic analysis states.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseCustomTypeWithFallback(Type &result, function_ref< ParseResult(Type &result)> parseType)=0
Parse a custom type with the provided callback, unless the next token is #, in which case the generic...
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
void printStrippedAttrOrType(AttrOrType attrOrType)
Print the provided attribute in the context of an operation custom printer/parser: this will invoke d...
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:108
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:100
IndexType getIndexType()
Definition Builders.cpp:51
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:257
unsigned getOperandNumber()
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:457
Simple wrapper around a void* in order to express generically how to pass in op properties through AP...
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:346
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
LogicalResult foldToBufferToTensorPair(RewriterBase &rewriter, ToBufferOp toBuffer, const BufferizationOptions &options)
Try to fold to_buffer(to_tensor(x)).
void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue, SmallVector< Value > &dynamicDims)
Populate dynamicDims with tensor::DimOp / memref::DimOp results for all dynamic dimensions of the giv...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition MemRefOps.cpp:60
std::optional< Operation * > findDealloc(Value allocValue)
Finds a single dealloc operation for the given allocated value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:45
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:66
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
const FrozenRewritePatternSet & patterns
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:126
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.