MLIR 23.0.0git
BufferizationOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "mlir/IR/Matchers.h"
17#include "llvm/ADT/SmallVectorExtras.h"
18#include <optional>
19
20using namespace mlir;
21using namespace mlir::bufferization;
22
23//===----------------------------------------------------------------------===//
24// Helper functions
25//===----------------------------------------------------------------------===//
26
28 OpBuilder &b, Value value, MemRefType destType,
30 auto srcType = llvm::cast<MemRefType>(value.getType());
31
32 // Element type and rank must match.
33 if (srcType.getElementType() != destType.getElementType())
34 return failure();
35 if (srcType.getRank() != destType.getRank())
36 return failure();
37
38 // In case the affine maps are different, we may need to use a copy if we go
39 // from dynamic to static offset or stride (the canonicalization cannot know
40 // at this point that it is really cast compatible).
41 auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
42 int64_t sourceOffset, targetOffset;
43 SmallVector<int64_t, 4> sourceStrides, targetStrides;
44 if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
45 failed(target.getStridesAndOffset(targetStrides, targetOffset)))
46 return false;
47 auto dynamicToStatic = [](int64_t a, int64_t b) {
48 return ShapedType::isDynamic(a) && ShapedType::isStatic(b);
49 };
50 if (dynamicToStatic(sourceOffset, targetOffset))
51 return false;
52 for (auto it : zip(sourceStrides, targetStrides))
53 if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
54 return false;
55 return true;
56 };
57
58 // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
59 // ensure that we only generate casts that always succeed at runtime, we check
60 // a fix extra conditions in `isGuaranteedCastCompatible`.
61 if (memref::CastOp::areCastCompatible(srcType, destType) &&
62 isGuaranteedCastCompatible(srcType, destType)) {
63 Value casted = memref::CastOp::create(b, value.getLoc(), destType, value);
64 return casted;
65 }
66
67 auto loc = value.getLoc();
68 SmallVector<Value, 4> dynamicOperands;
69 for (int i = 0; i < destType.getRank(); ++i) {
70 if (destType.getShape()[i] != ShapedType::kDynamic)
71 continue;
72 Value size = memref::DimOp::create(b, loc, value, i);
73 dynamicOperands.push_back(size);
74 }
75
76 FailureOr<Value> copy =
77 options.createAlloc(b, loc, destType, dynamicOperands);
78 if (failed(copy))
79 return failure();
80 if (failed(options.createMemCpy(b, loc, value, *copy)))
81 return failure();
82 return copy;
83}
84
85/// Try to fold to_buffer(to_tensor(x)). If x's type and the result type of the
86/// to_buffer op are different, a memref.cast is needed.
88 RewriterBase &rewriter, ToBufferOp toBuffer,
90 auto bufferToTensor = toBuffer.getTensor().getDefiningOp<ToTensorOp>();
91 if (!bufferToTensor)
92 return failure();
93
94 Type srcType = bufferToTensor.getBuffer().getType();
95 Type destType = toBuffer.getType();
96
97 // Directly rewrite if the type did not change.
98 if (srcType == destType) {
99 rewriter.replaceOp(toBuffer, bufferToTensor.getBuffer());
100 return success();
101 }
102
103 auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
104 auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
105 auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);
106
107 // Ranked memref -> Ranked memref cast.
108 if (rankedSrcType && rankedDestType) {
109 FailureOr<Value> replacement = castOrReallocMemRefValue(
110 rewriter, bufferToTensor.getBuffer(), rankedDestType, options);
111 if (failed(replacement))
112 return failure();
113
114 rewriter.replaceOp(toBuffer, *replacement);
115 return success();
116 }
117
118 // Unranked memref -> Ranked memref cast: May require a copy.
119 // TODO: Not implemented at the moment.
120 if (unrankedSrcType && rankedDestType)
121 return failure();
122
123 // Unranked memref -> unranked memref cast
124 // Ranked memref -> unranked memref cast: No copy needed.
125 assert(memref::CastOp::areCastCompatible(srcType, destType) &&
126 "expected that types are cast compatible");
127 rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, destType,
128 bufferToTensor.getBuffer());
129 return success();
130}
131
133 OpBuilder &b, Location loc, Value shapedValue,
134 SmallVector<Value> &dynamicDims) {
135 auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
136 for (int64_t i = 0; i < shapedType.getRank(); ++i) {
137 if (shapedType.isDynamicDim(i)) {
138 if (llvm::isa<MemRefType>(shapedType)) {
139 dynamicDims.push_back(memref::DimOp::create(b, loc, shapedValue, i));
140 } else {
141 assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
142 dynamicDims.push_back(tensor::DimOp::create(b, loc, shapedValue, i));
143 }
144 }
145 }
146}
147
148//===----------------------------------------------------------------------===//
149// AllocTensorOp
150//===----------------------------------------------------------------------===//
151
152LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
154 BufferizationState &state) {
155 OpBuilder::InsertionGuard g(rewriter);
156 Location loc = getLoc();
157
158 // Nothing to do for dead AllocTensorOps.
159 if (getOperation()->getUses().empty()) {
160 rewriter.eraseOp(getOperation());
161 return success();
162 }
163
164 // Get "copy" buffer.
165 Value copyBuffer;
166 if (getCopy()) {
167 FailureOr<Value> maybeCopyBuffer =
168 getBuffer(rewriter, getCopy(), options, state);
169 if (failed(maybeCopyBuffer))
170 return failure();
171 copyBuffer = *maybeCopyBuffer;
172 }
173
174 // Create memory allocation.
175 auto allocType = bufferization::getBufferType(getResult(), options, state);
176 if (failed(allocType))
177 return failure();
178 SmallVector<Value> dynamicDims = getDynamicSizes();
179 if (getCopy()) {
180 assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
181 populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
182 }
183 FailureOr<Value> alloc = options.createAlloc(
184 rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
185 if (failed(alloc))
186 return failure();
187
188 // Create memory copy (if any).
189 if (getCopy()) {
190 if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
191 return failure();
192 }
193
194 // Replace op.
195 replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
196
197 return success();
198}
199
200bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
201 const AnalysisState &state) {
202 // AllocTensorOps do not write unless they have a `copy` value.
203 return static_cast<bool>(getCopy());
204}
205
206bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
207 const AnalysisState &state) {
208 assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
209 "expected copy operand");
210 return true;
211}
212
213bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
214 const AnalysisState &state) {
215 assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
216 "expected copy operand");
217 return false;
218}
219
220AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
221 const AnalysisState &state) {
222 // This is a new allocation. It does not alias with any other buffer.
223 return {};
224}
225
226FailureOr<BufferLikeType>
227AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
228 const BufferizationState &state,
229 SmallVector<Value> &invocationStack) {
230 assert(value == getResult() && "invalid value");
231
232 // Compute memory space of this allocation.
233 Attribute memorySpace;
234 if (getMemorySpace().has_value()) {
235 memorySpace = *getMemorySpace();
236 } else if (getCopy()) {
237 auto copyBufferType =
238 bufferization::detail::asMemRefType(bufferization::getBufferType(
239 getCopy(), options, state, invocationStack));
240 if (failed(copyBufferType))
241 return failure();
242 memorySpace = copyBufferType->getMemorySpace();
243 } else if (auto ms = options.defaultMemorySpaceFn(
244 cast<TensorLikeType>(getType()))) {
245 memorySpace = *ms;
246 } else {
247 return getOperation()->emitError("could not infer memory space");
248 }
249
250 return cast<BufferLikeType>(
251 getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
252}
253
254LogicalResult AllocTensorOp::verify() {
255 if (getCopy() && !getDynamicSizes().empty())
256 return emitError("dynamic sizes not needed when copying a tensor");
257 if (!getCopy() && failed(verifyDynamicDimensionCount(
258 getOperation(), getType(), getDynamicSizes())))
259 return failure();
260 if (getCopy() && getCopy().getType() != getType())
261 return emitError("expected that `copy` and return type match");
262 return success();
263}
264
265void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
266 RankedTensorType type, ValueRange dynamicSizes) {
267 build(builder, result, type, dynamicSizes, /*copy=*/Value(),
268 /*size_hint=*/Value(),
269 /*memory_space=*/IntegerAttr());
270}
271
272void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
273 RankedTensorType type, ValueRange dynamicSizes,
274 Value copy) {
275 build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
276 /*memory_space=*/IntegerAttr());
277}
278
279void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
280 TensorType type, ValueRange dynamicSizes, Value copy,
281 IntegerAttr memorySpace) {
282 build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
283 memorySpace);
284}
285
286namespace {
287/// Change the type of the result of a `bufferization.alloc_tensor` by making
288/// the result type statically sized along dimension that in the original
289/// operation where defined as dynamic, but the size was defined using a
290/// `constant` op. For example:
291///
292/// %c5 = arith.constant 5: index
293/// %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
294///
295/// to
296///
297/// %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
298struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
299 using OpRewritePattern<AllocTensorOp>::OpRewritePattern;
300
301 LogicalResult matchAndRewrite(AllocTensorOp op,
302 PatternRewriter &rewriter) const override {
303 if (op.getCopy())
304 return failure();
305 SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
306 SmallVector<Value> newDynamicSizes;
307 unsigned int dynValCounter = 0;
308 for (int64_t i = 0; i < op.getType().getRank(); ++i) {
309 if (!op.isDynamicDim(i))
310 continue;
311 Value value = op.getDynamicSizes()[dynValCounter++];
312 APInt intVal;
313 if (matchPattern(value, m_ConstantInt(&intVal))) {
314 int64_t dim = intVal.getSExtValue();
315 if (dim >= 0)
316 newShape[i] = intVal.getSExtValue();
317 else
318 newDynamicSizes.push_back(value);
319 } else {
320 newDynamicSizes.push_back(value);
321 }
322 }
323 RankedTensorType newType = RankedTensorType::get(
324 newShape, op.getType().getElementType(), op.getType().getEncoding());
325 if (newType == op.getType())
326 return failure();
327 auto newOp = AllocTensorOp::create(rewriter, op.getLoc(), newType,
328 newDynamicSizes, /*copy=*/Value());
329 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
330 return success();
331 }
332};
333
334struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
335 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
336
337 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
338 PatternRewriter &rewriter) const override {
339 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
340 auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
341 if (!allocTensorOp || !maybeConstantIndex)
342 return failure();
343 if (*maybeConstantIndex < 0 ||
344 *maybeConstantIndex >= allocTensorOp.getType().getRank())
345 return failure();
346 if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
347 return failure();
348 rewriter.replaceOp(
349 dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
350 return success();
351 }
352};
353} // namespace
354
355void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
356 MLIRContext *ctx) {
357 results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
358}
359
360LogicalResult AllocTensorOp::reifyResultShapes(
361 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
362 auto shapes =
363 llvm::map_to_vector<4>(llvm::seq<int64_t>(0, getType().getRank()),
364 [&](int64_t dim) -> OpFoldResult {
365 if (isDynamicDim(dim))
366 return getDynamicSize(builder, dim);
367 return builder.getIndexAttr(getStaticSize(dim));
368 });
369 reifiedReturnShapes.emplace_back(std::move(shapes));
370 return success();
371}
372
373ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
375 if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
376 parser.parseRParen())
377 return failure();
378 ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
380 if (copyKeyword.succeeded())
381 if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
382 parser.parseRParen())
383 return failure();
384 ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
385 OpAsmParser::UnresolvedOperand sizeHintOperand;
386 if (sizeHintKeyword.succeeded())
387 if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
388 return failure();
389 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
390 return failure();
391
392 TensorType type;
393 if (parser.parseCustomTypeWithFallback(type))
394 return failure();
395 result.addTypes(type);
396
397 Type indexType = parser.getBuilder().getIndexType();
398 if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
399 return failure();
400 if (copyKeyword.succeeded())
401 if (parser.resolveOperand(copyOperand, type, result.operands))
402 return failure();
403 if (sizeHintKeyword.succeeded())
404 if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
405 return failure();
406 result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
408 {static_cast<int32_t>(dynamicSizesOperands.size()),
409 static_cast<int32_t>(copyKeyword.succeeded()),
410 static_cast<int32_t>(sizeHintKeyword.succeeded())}));
411 return success();
412}
413
414void AllocTensorOp::print(OpAsmPrinter &p) {
415 p << "(" << getDynamicSizes() << ")";
416 if (getCopy())
417 p << " copy(" << getCopy() << ")";
418 if (getSizeHint())
419 p << " size_hint=" << getSizeHint();
420 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
421 AllocTensorOp::getOperandSegmentSizeAttr()});
422 p << " : ";
423 auto type = getResult().getType();
424 if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
425 p.printStrippedAttrOrType(validType);
426 else
427 p << type;
428}
429
430Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
431 assert(isDynamicDim(idx) && "expected dynamic dim");
432 if (getCopy())
433 return tensor::DimOp::create(b, getLoc(), getCopy(), idx);
434 return getOperand(getIndexOfDynamicSize(idx));
435}
436
437//===----------------------------------------------------------------------===//
438// CloneOp
439//===----------------------------------------------------------------------===//
440
441OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
442 return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
443}
444
445namespace {
446
447/// Merge the clone and its source (by converting the clone to a cast) when
448/// possible.
449struct SimplifyClones : public OpRewritePattern<CloneOp> {
450 using OpRewritePattern<CloneOp>::OpRewritePattern;
451
452 LogicalResult matchAndRewrite(CloneOp cloneOp,
453 PatternRewriter &rewriter) const override {
454 if (cloneOp.use_empty()) {
455 rewriter.eraseOp(cloneOp);
456 return success();
457 }
458
459 Value source = cloneOp.getInput();
460 if (source.getType() != cloneOp.getType() &&
461 !memref::CastOp::areCastCompatible({source.getType()},
462 {cloneOp.getType()}))
463 return failure();
464
465 // Aims to find the dealloc op for the canonical source
466 // which otherwise could prevent removal of unnecessary allocs.
467 Value canonicalSource = source;
468 while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
469 canonicalSource.getDefiningOp())) {
470 if (canonicalSource != iface.getViewDest()) {
471 break;
472 }
473 canonicalSource = iface.getViewSource();
474 }
475
476 std::optional<Operation *> maybeCloneDeallocOp =
477 memref::findDealloc(cloneOp.getOutput());
478 // Skip if either of them has > 1 deallocate operations.
479 if (!maybeCloneDeallocOp.has_value())
480 return failure();
481 std::optional<Operation *> maybeSourceDeallocOp =
482 memref::findDealloc(canonicalSource);
483 if (!maybeSourceDeallocOp.has_value())
484 return failure();
485 Operation *cloneDeallocOp = *maybeCloneDeallocOp;
486 Operation *sourceDeallocOp = *maybeSourceDeallocOp;
487
488 // If both are deallocated in the same block, their in-block lifetimes
489 // might not fully overlap, so we cannot decide which one to drop.
490 if (cloneDeallocOp && sourceDeallocOp &&
491 cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
492 return failure();
493
494 Block *currentBlock = cloneOp->getBlock();
495 Operation *redundantDealloc = nullptr;
496 if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
497 redundantDealloc = cloneDeallocOp;
498 } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
499 redundantDealloc = sourceDeallocOp;
500 }
501
502 if (!redundantDealloc)
503 return failure();
504
505 // Safety check that there are no other deallocations inbetween
506 // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
507 // of source before the uses of the clone. With alias information, we could
508 // restrict this to only fail of the dealloc's operand is an alias
509 // of the source.
510 for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
511 pos = pos->getNextNode()) {
512 // Bail if we run out of operations while looking for a deallocation op.
513 if (!pos)
514 return failure();
515 auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
516 if (!effectInterface)
517 continue;
518 if (effectInterface.hasEffect<MemoryEffects::Free>())
519 return failure();
520 }
521
522 if (source.getType() != cloneOp.getType())
523 source = memref::CastOp::create(rewriter, cloneOp.getLoc(),
524 cloneOp.getType(), source);
525 rewriter.replaceOp(cloneOp, source);
526 rewriter.eraseOp(redundantDealloc);
527 return success();
528 }
529};
530
531} // namespace
532
533void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
534 MLIRContext *context) {
535 results.add<SimplifyClones>(context);
536}
537
538//===----------------------------------------------------------------------===//
539// DeallocTensorOp
540//===----------------------------------------------------------------------===//
541
542LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
544 BufferizationState &state) {
545 FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
546 if (failed(buffer))
547 return failure();
548 memref::DeallocOp::create(rewriter, getLoc(), *buffer);
549 rewriter.eraseOp(getOperation());
550 return success();
551}
552
553//===----------------------------------------------------------------------===//
554// MaterializeInDestinationOp
555//===----------------------------------------------------------------------===//
556
557bool MaterializeInDestinationOp::bufferizesToMemoryRead(
558 OpOperand &opOperand, const AnalysisState &state) {
559 return opOperand == getSourceMutable();
560}
561
562bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
563 OpOperand &opOperand, const AnalysisState &state) {
564 if (opOperand == getDestMutable()) {
565 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
566 return true;
567 }
568 return false;
569}
570
571bool MaterializeInDestinationOp::mustBufferizeInPlace(
572 OpOperand &opOperand, const AnalysisState &state) {
573 // The source is only read and not written, so it always bufferizes in-place
574 // by default. The destination is written and is forced to bufferize in-place
575 // (if it is a tensor).
576 return true;
577}
578
579AliasingValueList
580MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
581 const AnalysisState &state) {
582 if (opOperand == getDestMutable()) {
583 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
584 return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
585 }
586 return {};
587}
588
589LogicalResult
590MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
592 BufferizationState &state) {
593 bool tensorDest = isa<TensorType>(getDest().getType());
594 Value buffer;
595 if (tensorDest) {
596 FailureOr<Value> maybeBuffer =
597 getBuffer(rewriter, getDest(), options, state);
598 if (failed(maybeBuffer))
599 return failure();
600 buffer = *maybeBuffer;
601 } else {
602 assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
603 buffer = getDest();
604 }
605 auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
606 if (failed(srcBuffer))
607 return failure();
608 if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
609 return failure();
610 replaceOpWithBufferizedValues(rewriter, getOperation(),
611 tensorDest ? ValueRange(buffer) : ValueRange());
612 return success();
613}
614
615bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
616 const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
617 // As elements are copied from the "source" buffer to the "dest" buffer,
618 // already copied elements are not read a second time.
619 return true;
620}
621
622LogicalResult MaterializeInDestinationOp::reifyResultShapes(
623 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
624 if (getOperation()->getNumResults() == 1) {
625 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
626 reifiedReturnShapes.resize(1,
628 reifiedReturnShapes[0] =
629 tensor::getMixedSizes(builder, getLoc(), getDest());
630 }
631 return success();
632}
633
634Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
635 Location loc) {
636 if (isa<TensorType>(getDest().getType())) {
637 // The subset is the entire destination tensor.
638 return getDest();
639 }
640
641 // The "restrict" attribute is transferred from this op to the newly created
642 // to_tensor op. If this op does not the "restrict" attribute, the subset
643 // extraction cannot be built because there is no guarantee that there is no
644 // pre-existing "restrict" to_tensor op with the same/an aliasing destination.
645 if (!getRestrict())
646 return {};
647
648 // Build a bufferization.to_tensor op.
649 assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
650 assert(getRestrict() &&
651 "expected that ops with memrefs dest have 'restrict'");
652 setRestrict(false);
653 return ToTensorOp::create(
654 builder, loc, memref::getTensorTypeFromMemRefType(getDest().getType()),
655 getDest(),
656 /*restrict=*/true, getWritable());
657}
658
659bool MaterializeInDestinationOp::isEquivalentSubset(
660 Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
661 return equivalenceFn(getDest(), candidate);
662}
663
665MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
666 return {getDest()};
667}
668
669OpOperand &MaterializeInDestinationOp::getSourceOperand() {
670 return getOperation()->getOpOperand(0) /*source*/;
671}
672
673bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
674 SubsetOpInterface subsetOp,
675 function_ref<bool(Value, Value)> equivalenceFn) {
676 return false;
677}
678
679bool MaterializeInDestinationOp::operatesOnDisjointSubset(
680 SubsetOpInterface subsetOp,
681 function_ref<bool(Value, Value)> equivalenceFn) {
682 return false;
683}
684
685LogicalResult MaterializeInDestinationOp::verify() {
686 if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
687 return emitOpError("'dest' must be a tensor or a memref");
688 if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
689 if (getOperation()->getNumResults() != 1)
690 return emitOpError("tensor 'dest' implies exactly one tensor result");
691 if (destType != getResult().getType())
692 return emitOpError("result and 'dest' types must match");
693 }
694 if (isa<BaseMemRefType>(getDest().getType()) &&
695 getOperation()->getNumResults() != 0)
696 return emitOpError("memref 'dest' implies zero results");
697 if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
698 return emitOpError("'restrict' is valid only for memref destinations");
699 if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
700 return emitOpError("'writable' must be specified if and only if the "
701 "destination is of memref type");
702 TensorType srcType = getSource().getType();
703 ShapedType destType = cast<ShapedType>(getDest().getType());
704 if (srcType.hasRank() != destType.hasRank())
705 return emitOpError("source/destination shapes are incompatible");
706 if (srcType.hasRank()) {
707 if (failed(verifyRanksMatch(getOperation(), srcType, destType, "source",
708 "destination")))
709 return failure();
710 for (auto [src, dest] :
711 llvm::zip(srcType.getShape(), destType.getShape())) {
712 if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
713 // Cannot verify dynamic dimension size. Assume that that they match at
714 // runtime.
715 continue;
716 }
717 if (src != dest)
718 return emitOpError("source/destination shapes are incompatible");
719 }
720 }
721 return success();
722}
723
724void MaterializeInDestinationOp::build(OpBuilder &builder,
725 OperationState &state, Value source,
726 Value dest) {
727 auto destTensorType = dyn_cast<TensorType>(dest.getType());
728 build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
729 source, dest);
730}
731
732bool MaterializeInDestinationOp::isWritable(Value value,
733 const AnalysisState &state) {
734 return isa<TensorType>(getDest().getType()) ? true : getWritable();
735}
736
737MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
738 return getDestMutable();
739}
740
741void MaterializeInDestinationOp::getEffects(
743 &effects) {
744 if (isa<BaseMemRefType>(getDest().getType()))
745 effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
747}
748
749//===----------------------------------------------------------------------===//
750// ToTensorOp
751//===----------------------------------------------------------------------===//
752
753bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
754 return getWritable();
755}
756
757OpFoldResult ToTensorOp::fold(FoldAdaptor) {
758 if (auto toBuffer = getBuffer().getDefiningOp<ToBufferOp>())
759 // Approximate alias analysis by conservatively folding only when no there
760 // is no interleaved operation.
761 if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
762 toBuffer->getNextNode() == this->getOperation())
763 return toBuffer.getTensor();
764 return {};
765}
766
767namespace {
768struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
769 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
770
771 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
772 PatternRewriter &rewriter) const override {
773 auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
774 if (!memrefToTensorOp)
775 return failure();
776
777 rewriter.replaceOpWithNewOp<memref::DimOp>(
778 dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex());
779 return success();
780 }
781};
782} // namespace
783
784void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
785 MLIRContext *context) {
786 results.add<DimOfToTensorFolder>(context);
787}
788
789//===----------------------------------------------------------------------===//
790// ToBufferOp
791//===----------------------------------------------------------------------===//
792
793OpFoldResult ToBufferOp::fold(FoldAdaptor) {
794 if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
795 if (memrefToTensor.getBuffer().getType() == getType())
796 return memrefToTensor.getBuffer();
797 return {};
798}
799
800namespace {
801
802/// Replace tensor.cast + to_buffer by to_buffer + memref.cast.
803struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
804 using OpRewritePattern<ToBufferOp>::OpRewritePattern;
805
806 LogicalResult matchAndRewrite(ToBufferOp toBuffer,
807 PatternRewriter &rewriter) const final {
808 auto tensorCastOperand =
809 toBuffer.getOperand().getDefiningOp<tensor::CastOp>();
810 if (!tensorCastOperand)
811 return failure();
812 auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
813 tensorCastOperand.getOperand().getType());
814 if (!srcTensorType)
815 return failure();
816 auto currentOutputMemRefType =
817 dyn_cast<BaseMemRefType>(toBuffer.getResult().getType());
818 if (!currentOutputMemRefType)
819 return failure();
820
821 auto memrefType = currentOutputMemRefType.cloneWith(
822 srcTensorType.getShape(), srcTensorType.getElementType());
823 Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
824 tensorCastOperand.getOperand(),
825 toBuffer.getReadOnly());
826 rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
827 memref);
828 return success();
829 }
830};
831
832/// Canonicalize bufferization.to_tensor + bufferization.to_buffer. Insert a
833/// cast if necessary.
834struct ToBufferToTensorFolding : public OpRewritePattern<ToBufferOp> {
835 using OpRewritePattern<ToBufferOp>::OpRewritePattern;
836
837 LogicalResult matchAndRewrite(ToBufferOp toBuffer,
838 PatternRewriter &rewriter) const final {
839 BufferizationOptions options;
840 options.bufferAlignment = 0;
841 return foldToBufferToTensorPair(rewriter, toBuffer, options);
842 }
843};
844
845/// Fold a load on a to_buffer operation into an tensor.extract on the
846/// corresponding tensor.
847struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
848 using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
849
850 LogicalResult matchAndRewrite(memref::LoadOp load,
851 PatternRewriter &rewriter) const override {
852 auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
853 if (!toBuffer || !toBuffer.getReadOnly())
854 return failure();
855
856 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
857 load.getIndices());
858 return success();
859 }
860};
861
862/// Fold dim of a to_buffer into the dim of the tensor.
863struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
864 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
865
866 LogicalResult matchAndRewrite(memref::DimOp dimOp,
867 PatternRewriter &rewriter) const override {
868 auto castOp = dimOp.getSource().getDefiningOp<ToBufferOp>();
869 if (!castOp)
870 return failure();
871 Value newSource = castOp.getOperand();
872 rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
873 dimOp.getIndex());
874 return success();
875 }
876};
877
878} // namespace
879
880void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
881 MLIRContext *context) {
882 results.add<DimOfCastOp, LoadOfToBuffer, ToBufferOfCast,
883 ToBufferToTensorFolding>(context);
884}
885
886LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
888 BufferizationState &state) {
889 // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
890 (void)foldToBufferToTensorPair(rewriter, *this, options);
891 // Note: The return value of `bufferize` indicates whether there was an error
892 // or not. (And not whether the pattern matched or not.)
893 return success();
894}
895
896std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
897 Value alloc) {
898 return memref::DeallocOp::create(builder, alloc.getLoc(), alloc)
899 .getOperation();
900}
901
902std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
903 return CloneOp::create(builder, alloc.getLoc(), alloc).getResult();
904}
905
906//===----------------------------------------------------------------------===//
907// DeallocOp
908//===----------------------------------------------------------------------===//
909
910LogicalResult DeallocOp::inferReturnTypes(
911 MLIRContext *context, std::optional<::mlir::Location> location,
912 ValueRange operands, DictionaryAttr attributes, PropertyRef properties,
913 RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
914 DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
915 inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
916 IntegerType::get(context, 1));
917 return success();
918}
919
920LogicalResult DeallocOp::verify() {
921 if (getMemrefs().size() != getConditions().size())
922 return emitOpError(
923 "must have the same number of conditions as memrefs to deallocate");
924 if (getRetained().size() != getUpdatedConditions().size())
925 return emitOpError("must have the same number of updated conditions "
926 "(results) as retained operands");
927 return success();
928}
929
930static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
931 ValueRange memrefs,
932 ValueRange conditions,
933 PatternRewriter &rewriter) {
934 if (deallocOp.getMemrefs() == memrefs &&
935 deallocOp.getConditions() == conditions)
936 return failure();
937
938 rewriter.modifyOpInPlace(deallocOp, [&]() {
939 deallocOp.getMemrefsMutable().assign(memrefs);
940 deallocOp.getConditionsMutable().assign(conditions);
941 });
942 return success();
943}
944
945namespace {
946
947/// Remove duplicate values in the list of memrefs to be deallocated. We need to
948/// make sure the corresponding condition value is updated accordingly since
949/// their two conditions might not cover the same set of cases. In that case, we
950/// have to combine them (by computing the disjunction of them).
951/// Example:
952/// ```mlir
953/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
954/// ```
955/// is canonicalized to
956/// ```mlir
957/// %0 = arith.ori %arg1, %arg2 : i1
958/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
959/// ```
960struct DeallocRemoveDuplicateDeallocMemrefs
961 : public OpRewritePattern<DeallocOp> {
962 using OpRewritePattern<DeallocOp>::OpRewritePattern;
963
964 LogicalResult matchAndRewrite(DeallocOp deallocOp,
965 PatternRewriter &rewriter) const override {
966 // Unique memrefs to be deallocated.
967 DenseMap<Value, unsigned> memrefToCondition;
968 SmallVector<Value> newMemrefs, newConditions;
969 for (auto [i, memref, cond] :
970 llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
971 if (memrefToCondition.count(memref)) {
972 // If the dealloc conditions don't match, we need to make sure that the
973 // dealloc happens on the union of cases.
974 Value &newCond = newConditions[memrefToCondition[memref]];
975 if (newCond != cond)
976 newCond =
977 arith::OrIOp::create(rewriter, deallocOp.getLoc(), newCond, cond);
978 } else {
979 memrefToCondition.insert({memref, newConditions.size()});
980 newMemrefs.push_back(memref);
981 newConditions.push_back(cond);
982 }
983 }
984
985 // Return failure if we don't change anything such that we don't run into an
986 // infinite loop of pattern applications.
987 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
988 rewriter);
989 }
990};
991
992/// Remove duplicate values in the list of retained memrefs. We need to make
993/// sure the corresponding result condition value is replaced properly.
994/// Example:
995/// ```mlir
996/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
997/// ```
998/// is canonicalized to
999/// ```mlir
1000/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
1001/// ```
1002struct DeallocRemoveDuplicateRetainedMemrefs
1003 : public OpRewritePattern<DeallocOp> {
1004 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1005
1006 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1007 PatternRewriter &rewriter) const override {
1008 // Unique retained values
1010 SmallVector<Value> newRetained;
1011 SmallVector<unsigned> resultReplacementIdx;
1012 unsigned i = 0;
1013 for (auto retained : deallocOp.getRetained()) {
1014 if (seen.count(retained)) {
1015 resultReplacementIdx.push_back(seen[retained]);
1016 continue;
1017 }
1018
1019 seen[retained] = i;
1020 newRetained.push_back(retained);
1021 resultReplacementIdx.push_back(i++);
1022 }
1023
1024 // Return failure if we don't change anything such that we don't run into an
1025 // infinite loop of pattern applications.
1026 if (newRetained.size() == deallocOp.getRetained().size())
1027 return failure();
1028
1029 // We need to create a new op because the number of results is always the
1030 // same as the number of condition operands.
1031 auto newDeallocOp =
1032 DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
1033 deallocOp.getConditions(), newRetained);
1034 SmallVector<Value> replacements(
1035 llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
1036 return newDeallocOp.getUpdatedConditions()[idx];
1037 }));
1038 rewriter.replaceOp(deallocOp, replacements);
1039 return success();
1040 }
1041};
1042
1043/// Erase deallocation operations where the variadic list of memrefs to
1044/// deallocate is empty. Example:
1045/// ```mlir
1046/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
1047/// ```
1048struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
1049 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1050
1051 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1052 PatternRewriter &rewriter) const override {
1053 if (deallocOp.getMemrefs().empty()) {
1054 Value constFalse = arith::ConstantOp::create(rewriter, deallocOp.getLoc(),
1055 rewriter.getBoolAttr(false));
1056 rewriter.replaceOp(
1057 deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
1058 constFalse));
1059 return success();
1060 }
1061 return failure();
1062 }
1063};
1064
1065/// Removes memrefs from the deallocation list if their associated condition is
1066/// always 'false'.
1067///
1068/// Example:
1069/// ```
1070/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
1071/// if (%arg2, %false)
1072/// ```
1073/// becomes
1074/// ```
1075/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
1076/// ```
1077struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
1078 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1079
1080 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1081 PatternRewriter &rewriter) const override {
1082 SmallVector<Value> newMemrefs, newConditions;
1083 for (auto [memref, cond] :
1084 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1085 if (!matchPattern(cond, m_Zero())) {
1086 newMemrefs.push_back(memref);
1087 newConditions.push_back(cond);
1088 }
1089 }
1090
1091 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1092 rewriter);
1093 }
1094};
1095
1096/// The `memref.extract_strided_metadata` is often inserted to get the base
1097/// memref if the operand is not already guaranteed to be the result of a memref
1098/// allocation operation. This canonicalization pattern removes this extraction
1099/// operation if the operand is now produced by an allocation operation (e.g.,
1100/// due to other canonicalizations simplifying the IR).
1101///
1102/// Example:
1103/// ```mlir
1104/// %alloc = memref.alloc() : memref<2xi32>
1105/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
1106/// %alloc : memref<2xi32> -> memref<i32>, index, index, index
1107/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
1108/// ```
1109/// is canonicalized to
1110/// ```mlir
1111/// %alloc = memref.alloc() : memref<2xi32>
1112/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
1113/// ```
1114struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
1115 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1116
1117 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1118 PatternRewriter &rewriter) const override {
1119 SmallVector<Value> newMemrefs(
1120 llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
1121 auto extractStridedOp =
1122 memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
1123 if (!extractStridedOp)
1124 return memref;
1125 Value allocMemref = extractStridedOp.getOperand();
1126 auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
1127 if (!allocOp)
1128 return memref;
1129 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
1130 return allocMemref;
1131 return memref;
1132 }));
1133
1134 return updateDeallocIfChanged(deallocOp, newMemrefs,
1135 deallocOp.getConditions(), rewriter);
1136 }
1137};
1138
1139/// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
1140/// other user of the allocated value and the allocating operation can be safely
1141/// removed. If the same value is present multiple times, this pattern relies on
1142/// other canonicalization patterns to remove the duplicate first.
1143///
1144/// Example:
1145/// ```mlir
1146/// %alloc = memref.alloc() : memref<2xi32>
1147/// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true)
1148/// ```
1149/// is canonicalized to
1150/// ```mlir
1151/// bufferization.dealloc (%arg0 : ...) if (%true)
1152/// ```
1153struct RemoveAllocDeallocPairWhenNoOtherUsers
1154 : public OpRewritePattern<DeallocOp> {
1155 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1156
1157 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1158 PatternRewriter &rewriter) const override {
1159 SmallVector<Value> newMemrefs, newConditions;
1160 SmallVector<Operation *> toDelete;
1161 for (auto [memref, cond] :
1162 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1163 if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
1164 // Check that it is indeed an allocate effect, that the op has no other
1165 // side effects (which would not allow us to remove the op), and that
1166 // there are no other users.
1167 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
1169 memref.hasOneUse()) {
1170 toDelete.push_back(allocOp);
1171 continue;
1172 }
1173 }
1174
1175 newMemrefs.push_back(memref);
1176 newConditions.push_back(cond);
1177 }
1178
1179 if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1180 rewriter)))
1181 return failure();
1182
1183 for (Operation *op : toDelete)
1184 rewriter.eraseOp(op);
1185
1186 return success();
1187 }
1188};
1189
1190} // anonymous namespace
1191
1192void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
1193 MLIRContext *context) {
1195}
1196
1198 RewritePatternSet &patterns, MLIRContext *context) {
1199 patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
1200 DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1201 EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1202 RemoveAllocDeallocPairWhenNoOtherUsers>(context);
1203}
1204
1205//===----------------------------------------------------------------------===//
1206// TableGen'd op method definitions
1207//===----------------------------------------------------------------------===//
1208
1209#define GET_OP_CLASSES
1210#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static SmallVector< Value > getDynamicSize(Value memref, func::FuncOp funcOp)
Return the dynamic shapes of the memref based on the defining op.
static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp, ValueRange memrefs, ValueRange conditions, PatternRewriter &rewriter)
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
true
Given two iterators into the same block, return "true" if `a` is before `b`.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
auto load
*if copies could not be generated due to yet unimplemented cases *copyInPlacementStart and copyOutPlacementStart in copyPlacementBlock *specify the insertion points where the incoming copies and outgoing should be the output argument nBegin is set to its * replacement(set to `begin` if no invalidation happens). Since outgoing *copies could have been inserted at `end`
static llvm::ManagedStatic< PassManagerOptions > options
template bool mlir::hasSingleEffect< MemoryEffects::Allocate >(Operation *)
static void getDynamicSizes(RankedTensorType tp, ValueRange sizes, SmallVectorImpl< Value > &dynSizes)
Collects the dynamic dimension sizes for tp with the assumption that sizes are the dimension sizes fo...
Base class for generic analysis states.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseCustomTypeWithFallback(Type &result, function_ref< ParseResult(Type &result)> parseType)=0
Parse a custom type with the provided callback, unless the next token is #, in which case the generic...
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
void printStrippedAttrOrType(AttrOrType attrOrType)
Print the provided attribute in the context of an operation custom printer/parser: this will invoke d...
Attributes are known-constant values of operations.
Definition Attributes.h:25
IntegerAttr getIndexAttr(int64_t value)
Definition Builders.cpp:112
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
BoolAttr getBoolAttr(bool value)
Definition Builders.cpp:104
IndexType getIndexType()
Definition Builders.cpp:55
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
This class represents a single result from folding an operation.
This class represents an operand of an operation.
Definition Value.h:254
unsigned getOperandNumber() const
Return which operand this is in the OpOperand list of the Operation.
Definition Value.cpp:226
This is a value defined by a result of an operation.
Definition Value.h:454
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Type-safe wrapper around a void* for passing properties, including the properties structs of operatio...
This class provides an abstraction over the different types of ranges over Regions.
Definition Region.h:357
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
Tensor types represent multi-dimensional arrays, and have two variants: RankedTensorType and Unranked...
ArrayRef< int64_t > getShape() const
Returns the shape of this tensor type.
bool hasRank() const
Returns if this type is ranked, i.e. it has a known number of dimensions.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Location getLoc() const
Return the location of this value.
Definition Value.cpp:24
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
void populateDeallocOpCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context)
Add the canonicalization patterns for bufferization.dealloc to the given pattern set to make them ava...
FailureOr< Value > castOrReallocMemRefValue(OpBuilder &b, Value value, MemRefType type, const BufferizationOptions &options)
Try to cast the given ranked MemRef-typed value to the given ranked MemRef type.
LogicalResult foldToBufferToTensorPair(RewriterBase &rewriter, ToBufferOp toBuffer, const BufferizationOptions &options)
Try to fold to_buffer(to_tensor(x)).
void populateDynamicDimSizes(OpBuilder &b, Location loc, Value shapedValue, SmallVector< Value > &dynamicDims)
Populate dynamicDims with tensor::DimOp / memref::DimOp results for all dynamic dimensions of the giv...
Type getTensorTypeFromMemRefType(Type type)
Return an unranked/ranked tensor type for the given unranked/ranked memref type.
Definition MemRefOps.cpp:62
std::optional< Operation * > findDealloc(Value allocValue)
Finds a single dealloc operation for the given allocated value.
LogicalResult foldMemRefCast(Operation *op, Value inner=nullptr)
This is a common utility used for patterns of the form "someop(memref.cast) -> someop".
Definition MemRefOps.cpp:47
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Definition TensorOps.cpp:69
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
LogicalResult verifyDynamicDimensionCount(Operation *op, ShapedType type, ValueRange dynamicSizes)
Verify that the number of dynamic size operands matches the number of dynamic dimensions in the shape...
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
SmallVector< SmallVector< OpFoldResult > > ReifiedRankedShapedTypeDims
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
LogicalResult verifyRanksMatch(Operation *op, ShapedType lhs, ShapedType rhs, StringRef lhsName, StringRef rhsName)
Verify that two shaped types have matching ranks.
llvm::DenseMap< KeyT, ValueT, KeyInfoT, BucketT > DenseMap
Definition LLVM.h:120
llvm::function_ref< Fn > function_ref
Definition LLVM.h:147
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.