MLIR 23.0.0git
BufferizationOps.cpp
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "mlir/IR/Matchers.h"
17#include "llvm/ADT/SmallVectorExtras.h"
18#include <optional>
19
20using namespace mlir;
21using namespace mlir::bufferization;
22
23//===----------------------------------------------------------------------===//
24// Helper functions
25//===----------------------------------------------------------------------===//
26
27FailureOr<Value> mlir::bufferization::castOrReallocMemRefValue(
28 OpBuilder &b, Value value, MemRefType destType,
29 const BufferizationOptions &options) {
30 auto srcType = llvm::cast<MemRefType>(value.getType());
31
32 // Element type and rank must match.
33 if (srcType.getElementType() != destType.getElementType())
34 return failure();
35 if (srcType.getRank() != destType.getRank())
36 return failure();
37
38 // In case the affine maps are different, we may need to use a copy if we go
39 // from dynamic to static offset or stride (the canonicalization cannot know
40 // at this point that it is really cast compatible).
41 auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
42 int64_t sourceOffset, targetOffset;
43 SmallVector<int64_t, 4> sourceStrides, targetStrides;
44 if (failed(source.getStridesAndOffset(sourceStrides, sourceOffset)) ||
45 failed(target.getStridesAndOffset(targetStrides, targetOffset)))
46 return false;
47 auto dynamicToStatic = [](int64_t a, int64_t b) {
48 return ShapedType::isDynamic(a) && ShapedType::isStatic(b);
49 };
50 if (dynamicToStatic(sourceOffset, targetOffset))
51 return false;
52 for (auto it : zip(sourceStrides, targetStrides))
53 if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
54 return false;
55 return true;
56 };
57
58 // Note: If `areCastCompatible`, a cast is valid, but may fail at runtime. To
59 // ensure that we only generate casts that always succeed at runtime, we check
60 // a few extra conditions in `isGuaranteedCastCompatible`.
61 if (memref::CastOp::areCastCompatible(srcType, destType) &&
62 isGuaranteedCastCompatible(srcType, destType)) {
63 Value casted = memref::CastOp::create(b, value.getLoc(), destType, value);
64 return casted;
65 }
66
67 auto loc = value.getLoc();
68 SmallVector<Value, 4> dynamicOperands;
69 for (int i = 0; i < destType.getRank(); ++i) {
70 if (destType.getShape()[i] != ShapedType::kDynamic)
71 continue;
72 Value size = memref::DimOp::create(b, loc, value, i);
73 dynamicOperands.push_back(size);
74 }
75
76 FailureOr<Value> copy =
77 options.createAlloc(b, loc, destType, dynamicOperands);
78 if (failed(copy))
79 return failure();
80 if (failed(options.createMemCpy(b, loc, value, *copy)))
81 return failure();
82 return copy;
83}
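// Editorial sketch (not part of the original source): when the cast is
// guaranteed to succeed, the helper above emits a single cast, roughly
//
//   %0 = memref.cast %src : memref<4xf32> to memref<?xf32>
//
// A cast that could fail at runtime (e.g. dynamic offset/stride cast to a
// static one) instead takes the allocation-plus-copy path, roughly
//
//   %d = memref.dim %src, %c0 : memref<?xf32, strided<[?], offset: ?>>
//   %alloc = memref.alloc(%d) : memref<?xf32>   // via options.createAlloc
//   memref.copy %src, %alloc ...                // via options.createMemCpy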
84
85/// Try to fold to_buffer(to_tensor(x)). If x's type and the result type of the
86/// to_buffer op are different, a memref.cast is needed.
87LogicalResult mlir::bufferization::foldToBufferToTensorPair(
88 RewriterBase &rewriter, ToBufferOp toBuffer,
89 const BufferizationOptions &options) {
90 auto bufferToTensor = toBuffer.getTensor().getDefiningOp<ToTensorOp>();
91 if (!bufferToTensor)
92 return failure();
93
94 Type srcType = bufferToTensor.getBuffer().getType();
95 Type destType = toBuffer.getType();
96
97 // Directly rewrite if the type did not change.
98 if (srcType == destType) {
99 rewriter.replaceOp(toBuffer, bufferToTensor.getBuffer());
100 return success();
101 }
102
103 auto rankedSrcType = llvm::dyn_cast<MemRefType>(srcType);
104 auto rankedDestType = llvm::dyn_cast<MemRefType>(destType);
105 auto unrankedSrcType = llvm::dyn_cast<UnrankedMemRefType>(srcType);
106
107 // Ranked memref -> Ranked memref cast.
108 if (rankedSrcType && rankedDestType) {
109 FailureOr<Value> replacement = castOrReallocMemRefValue(
110 rewriter, bufferToTensor.getBuffer(), rankedDestType, options);
111 if (failed(replacement))
112 return failure();
113
114 rewriter.replaceOp(toBuffer, *replacement);
115 return success();
116 }
117
118 // Unranked memref -> Ranked memref cast: May require a copy.
119 // TODO: Not implemented at the moment.
120 if (unrankedSrcType && rankedDestType)
121 return failure();
122
123 // Unranked memref -> unranked memref cast
124 // Ranked memref -> unranked memref cast: No copy needed.
125 assert(memref::CastOp::areCastCompatible(srcType, destType) &&
126 "expected that types are cast compatible");
127 rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, destType,
128 bufferToTensor.getBuffer());
129 return success();
130}
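// Editorial sketch (not part of the original source): a matching pair such as
//
//   %t = bufferization.to_tensor %m : memref<?xf32> to tensor<?xf32>
//   %b = bufferization.to_buffer %t : tensor<?xf32> to memref<?xf32>
//
// folds to a direct use of %m. If the buffer types differ, a memref.cast (or,
// for ranked types, the cast-or-realloc helper above) is used instead.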
131
132void mlir::bufferization::populateDynamicDimSizes(
133 OpBuilder &b, Location loc, Value shapedValue,
134 SmallVector<Value> &dynamicDims) {
135 auto shapedType = llvm::cast<ShapedType>(shapedValue.getType());
136 for (int64_t i = 0; i < shapedType.getRank(); ++i) {
137 if (shapedType.isDynamicDim(i)) {
138 if (llvm::isa<MemRefType>(shapedType)) {
139 dynamicDims.push_back(memref::DimOp::create(b, loc, shapedValue, i));
140 } else {
141 assert(llvm::isa<RankedTensorType>(shapedType) && "expected tensor");
142 dynamicDims.push_back(tensor::DimOp::create(b, loc, shapedValue, i));
143 }
144 }
145 }
146}
147
148//===----------------------------------------------------------------------===//
149// AllocTensorOp
150//===----------------------------------------------------------------------===//
151
152LogicalResult AllocTensorOp::bufferize(RewriterBase &rewriter,
153 const BufferizationOptions &options,
154 BufferizationState &state) {
155 OpBuilder::InsertionGuard g(rewriter);
156 Location loc = getLoc();
157
158 // Nothing to do for dead AllocTensorOps.
159 if (getOperation()->getUses().empty()) {
160 rewriter.eraseOp(getOperation());
161 return success();
162 }
163
164 // Get "copy" buffer.
165 Value copyBuffer;
166 if (getCopy()) {
167 FailureOr<Value> maybeCopyBuffer =
168 getBuffer(rewriter, getCopy(), options, state);
169 if (failed(maybeCopyBuffer))
170 return failure();
171 copyBuffer = *maybeCopyBuffer;
172 }
173
174 // Create memory allocation.
175 auto allocType = bufferization::getBufferType(getResult(), options, state);
176 if (failed(allocType))
177 return failure();
178 SmallVector<Value> dynamicDims = getDynamicSizes();
179 if (getCopy()) {
180 assert(dynamicDims.empty() && "expected either `copy` or `dynamicDims`");
181 populateDynamicDimSizes(rewriter, loc, copyBuffer, dynamicDims);
182 }
183 FailureOr<Value> alloc = options.createAlloc(
184 rewriter, loc, llvm::cast<MemRefType>(*allocType), dynamicDims);
185 if (failed(alloc))
186 return failure();
187
188 // Create memory copy (if any).
189 if (getCopy()) {
190 if (failed(options.createMemCpy(rewriter, loc, copyBuffer, *alloc)))
191 return failure();
192 }
193
194 // Replace op.
195 replaceOpWithBufferizedValues(rewriter, getOperation(), *alloc);
196
197 return success();
198}
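// Editorial sketch (not part of the original source): with the default
// allocation function, an op such as
//
//   %0 = bufferization.alloc_tensor(%sz) : tensor<?xf32>
//
// bufferizes to roughly
//
//   %alloc = memref.alloc(%sz) : memref<?xf32>
//
// plus a memcpy from the bufferized `copy` operand into %alloc, if present.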
199
200bool AllocTensorOp::resultBufferizesToMemoryWrite(OpResult opResult,
201 const AnalysisState &state) {
202 // AllocTensorOps do not write unless they have a `copy` value.
203 return static_cast<bool>(getCopy());
204}
205
206bool AllocTensorOp::bufferizesToMemoryRead(OpOperand &opOperand,
207 const AnalysisState &state) {
208 assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
209 "expected copy operand");
210 return true;
211}
212
213bool AllocTensorOp::bufferizesToMemoryWrite(OpOperand &opOperand,
214 const AnalysisState &state) {
215 assert(opOperand.getOperandNumber() == getNumOperands() - 1 &&
216 "expected copy operand");
217 return false;
218}
219
220AliasingValueList AllocTensorOp::getAliasingValues(OpOperand &opOperand,
221 const AnalysisState &state) {
222 // This is a new allocation. It does not alias with any other buffer.
223 return {};
224}
225
226FailureOr<BufferLikeType>
227AllocTensorOp::getBufferType(Value value, const BufferizationOptions &options,
228 const BufferizationState &state,
229 SmallVector<Value> &invocationStack) {
230 assert(value == getResult() && "invalid value");
231
232 // Compute memory space of this allocation.
233 Attribute memorySpace;
234 if (getMemorySpace().has_value()) {
235 memorySpace = *getMemorySpace();
236 } else if (getCopy()) {
237 auto copyBufferType =
238 bufferization::detail::asMemRefType(bufferization::getBufferType(
239 getCopy(), options, state, invocationStack));
240 if (failed(copyBufferType))
241 return failure();
242 memorySpace = copyBufferType->getMemorySpace();
243 } else if (auto ms = options.defaultMemorySpaceFn(getType())) {
244 memorySpace = *ms;
245 } else {
246 return getOperation()->emitError("could not infer memory space");
247 }
248
249 return cast<BufferLikeType>(
250 getMemRefTypeWithStaticIdentityLayout(getType(), memorySpace));
251}
252
253LogicalResult AllocTensorOp::verify() {
254 if (getCopy() && !getDynamicSizes().empty())
255 return emitError("dynamic sizes not needed when copying a tensor");
256 if (!getCopy() && failed(verifyDynamicDimensionCount(
257 getOperation(), getType(), getDynamicSizes())))
258 return failure();
259 if (getCopy() && getCopy().getType() != getType())
260 return emitError("expected that `copy` and return type match");
261 return success();
262}
263
264void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
265 RankedTensorType type, ValueRange dynamicSizes) {
266 build(builder, result, type, dynamicSizes, /*copy=*/Value(),
267 /*size_hint=*/Value(),
268 /*memory_space=*/IntegerAttr());
269}
270
271void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
272 RankedTensorType type, ValueRange dynamicSizes,
273 Value copy) {
274 build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
275 /*memory_space=*/IntegerAttr());
276}
277
278void AllocTensorOp::build(OpBuilder &builder, OperationState &result,
279 TensorType type, ValueRange dynamicSizes, Value copy,
280 IntegerAttr memorySpace) {
281 build(builder, result, type, dynamicSizes, copy, /*size_hint=*/Value(),
282 memorySpace);
283}
284
285namespace {
286/// Change the type of the result of a `bufferization.alloc_tensor` by making
287/// the result type statically sized along dimensions that in the original
288/// operation were defined as dynamic, but the size was defined using a
289/// `constant` op. For example:
290///
291/// %c5 = arith.constant 5: index
292/// %0 = bufferization.alloc_tensor(%arg0, %c5) : tensor<?x?xf32>
293///
294/// to
295///
296/// %0 = bufferization.alloc_tensor(%arg0) : tensor<?x5xf32>
297struct ReplaceStaticShapeDims : OpRewritePattern<AllocTensorOp> {
298 using OpRewritePattern<AllocTensorOp>::OpRewritePattern;
299
300 LogicalResult matchAndRewrite(AllocTensorOp op,
301 PatternRewriter &rewriter) const override {
302 if (op.getCopy())
303 return failure();
304 SmallVector<int64_t> newShape = llvm::to_vector(op.getType().getShape());
305 SmallVector<Value> newDynamicSizes;
306 unsigned int dynValCounter = 0;
307 for (int64_t i = 0; i < op.getType().getRank(); ++i) {
308 if (!op.isDynamicDim(i))
309 continue;
310 Value value = op.getDynamicSizes()[dynValCounter++];
311 APInt intVal;
312 if (matchPattern(value, m_ConstantInt(&intVal))) {
313 int64_t dim = intVal.getSExtValue();
314 if (dim >= 0)
315 newShape[i] = intVal.getSExtValue();
316 else
317 newDynamicSizes.push_back(value);
318 } else {
319 newDynamicSizes.push_back(value);
320 }
321 }
322 RankedTensorType newType = RankedTensorType::get(
323 newShape, op.getType().getElementType(), op.getType().getEncoding());
324 if (newType == op.getType())
325 return failure();
326 auto newOp = AllocTensorOp::create(rewriter, op.getLoc(), newType,
327 newDynamicSizes, /*copy=*/Value());
328 rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), newOp);
329 return success();
330 }
331};
332
333struct FoldDimOfAllocTensorOp : public OpRewritePattern<tensor::DimOp> {
334 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
335
336 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
337 PatternRewriter &rewriter) const override {
338 std::optional<int64_t> maybeConstantIndex = dimOp.getConstantIndex();
339 auto allocTensorOp = dimOp.getSource().getDefiningOp<AllocTensorOp>();
340 if (!allocTensorOp || !maybeConstantIndex)
341 return failure();
342 if (*maybeConstantIndex < 0 ||
343 *maybeConstantIndex >= allocTensorOp.getType().getRank())
344 return failure();
345 if (!allocTensorOp.getType().isDynamicDim(*maybeConstantIndex))
346 return failure();
347 rewriter.replaceOp(
348 dimOp, allocTensorOp.getDynamicSize(rewriter, *maybeConstantIndex));
349 return success();
350 }
351};
352} // namespace
353
354void AllocTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
355 MLIRContext *ctx) {
356 results.add<FoldDimOfAllocTensorOp, ReplaceStaticShapeDims>(ctx);
357}
358
359LogicalResult AllocTensorOp::reifyResultShapes(
360 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
361 auto shapes =
362 llvm::map_to_vector<4>(llvm::seq<int64_t>(0, getType().getRank()),
363 [&](int64_t dim) -> OpFoldResult {
364 if (isDynamicDim(dim))
365 return getDynamicSize(builder, dim);
366 return builder.getIndexAttr(getStaticSize(dim));
367 });
368 reifiedReturnShapes.emplace_back(std::move(shapes));
369 return success();
370}
371
372ParseResult AllocTensorOp::parse(OpAsmParser &parser, OperationState &result) {
373 SmallVector<OpAsmParser::UnresolvedOperand> dynamicSizesOperands;
374 if (parser.parseLParen() || parser.parseOperandList(dynamicSizesOperands) ||
375 parser.parseRParen())
376 return failure();
377 ParseResult copyKeyword = parser.parseOptionalKeyword("copy");
378 OpAsmParser::UnresolvedOperand copyOperand;
379 if (copyKeyword.succeeded())
380 if (parser.parseLParen() || parser.parseOperand(copyOperand) ||
381 parser.parseRParen())
382 return failure();
383 ParseResult sizeHintKeyword = parser.parseOptionalKeyword("size_hint");
384 OpAsmParser::UnresolvedOperand sizeHintOperand;
385 if (sizeHintKeyword.succeeded())
386 if (parser.parseEqual() || parser.parseOperand(sizeHintOperand))
387 return failure();
388 if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColon())
389 return failure();
390
391 TensorType type;
392 if (parser.parseCustomTypeWithFallback(type))
393 return failure();
394 result.addTypes(type);
395
396 Type indexType = parser.getBuilder().getIndexType();
397 if (parser.resolveOperands(dynamicSizesOperands, indexType, result.operands))
398 return failure();
399 if (copyKeyword.succeeded())
400 if (parser.resolveOperand(copyOperand, type, result.operands))
401 return failure();
402 if (sizeHintKeyword.succeeded())
403 if (parser.resolveOperand(sizeHintOperand, indexType, result.operands))
404 return failure();
405 result.addAttribute(AllocTensorOp::getOperandSegmentSizeAttr(),
406 parser.getBuilder().getDenseI32ArrayAttr(
407 {static_cast<int32_t>(dynamicSizesOperands.size()),
408 static_cast<int32_t>(copyKeyword.succeeded()),
409 static_cast<int32_t>(sizeHintKeyword.succeeded())}));
410 return success();
411}
412
413void AllocTensorOp::print(OpAsmPrinter &p) {
414 p << "(" << getDynamicSizes() << ")";
415 if (getCopy())
416 p << " copy(" << getCopy() << ")";
417 if (getSizeHint())
418 p << " size_hint=" << getSizeHint();
419 p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{
420 AllocTensorOp::getOperandSegmentSizeAttr()});
421 p << " : ";
422 auto type = getResult().getType();
423 if (auto validType = llvm::dyn_cast<::mlir::TensorType>(type))
424 p.printStrippedAttrOrType(validType);
425 else
426 p << type;
427}
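// Editorial sketch (not part of the original source): the custom assembly
// handled by parse()/print() above looks roughly like
//
//   %0 = bufferization.alloc_tensor(%d0) : tensor<?x5xf32>
//   %1 = bufferization.alloc_tensor() copy(%t) : tensor<8xf32>
//
// with an optional `size_hint=%n` operand printed before the attribute dict.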
428
429Value AllocTensorOp::getDynamicSize(OpBuilder &b, unsigned idx) {
430 assert(isDynamicDim(idx) && "expected dynamic dim");
431 if (getCopy())
432 return tensor::DimOp::create(b, getLoc(), getCopy(), idx);
433 return getOperand(getIndexOfDynamicSize(idx));
434}
435
436//===----------------------------------------------------------------------===//
437// CloneOp
438//===----------------------------------------------------------------------===//
439
440OpFoldResult CloneOp::fold(FoldAdaptor adaptor) {
441 return succeeded(memref::foldMemRefCast(*this)) ? getResult() : Value();
442}
443
444namespace {
445
446/// Merge the clone and its source (by converting the clone to a cast) when
447/// possible.
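/// Editorial sketch (not part of the original source):
///
///   %1 = bufferization.clone %0 : memref<?xf32> to memref<?xf32>
///   "use"(%1) : (memref<?xf32>) -> ()
///   memref.dealloc %1 : memref<?xf32>
///
/// becomes a direct use of %0 (through a memref.cast if the types differ),
/// with the now-redundant dealloc erased.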
448struct SimplifyClones : public OpRewritePattern<CloneOp> {
449 using OpRewritePattern<CloneOp>::OpRewritePattern;
450
451 LogicalResult matchAndRewrite(CloneOp cloneOp,
452 PatternRewriter &rewriter) const override {
453 if (cloneOp.use_empty()) {
454 rewriter.eraseOp(cloneOp);
455 return success();
456 }
457
458 Value source = cloneOp.getInput();
459 if (source.getType() != cloneOp.getType() &&
460 !memref::CastOp::areCastCompatible({source.getType()},
461 {cloneOp.getType()}))
462 return failure();
463
464 // Aims to find the dealloc op for the canonical source
465 // which otherwise could prevent removal of unnecessary allocs.
466 Value canonicalSource = source;
467 while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
468 canonicalSource.getDefiningOp())) {
469 if (canonicalSource != iface.getViewDest()) {
470 break;
471 }
472 canonicalSource = iface.getViewSource();
473 }
474
475 std::optional<Operation *> maybeCloneDeallocOp =
476 memref::findDealloc(cloneOp.getOutput());
477 // Skip if either of them has > 1 deallocate operations.
478 if (!maybeCloneDeallocOp.has_value())
479 return failure();
480 std::optional<Operation *> maybeSourceDeallocOp =
481 memref::findDealloc(canonicalSource);
482 if (!maybeSourceDeallocOp.has_value())
483 return failure();
484 Operation *cloneDeallocOp = *maybeCloneDeallocOp;
485 Operation *sourceDeallocOp = *maybeSourceDeallocOp;
486
487 // If both are deallocated in the same block, their in-block lifetimes
488 // might not fully overlap, so we cannot decide which one to drop.
489 if (cloneDeallocOp && sourceDeallocOp &&
490 cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
491 return failure();
492
493 Block *currentBlock = cloneOp->getBlock();
494 Operation *redundantDealloc = nullptr;
495 if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
496 redundantDealloc = cloneDeallocOp;
497 } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
498 redundantDealloc = sourceDeallocOp;
499 }
500
501 if (!redundantDealloc)
502 return failure();
503
504 // Safety check that there are no other deallocations in between
505 // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
506 // of source before the uses of the clone. With alias information, we could
507 // restrict this to only fail if the dealloc's operand is an alias
508 // of the source.
509 for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
510 pos = pos->getNextNode()) {
511 // Bail if we run out of operations while looking for a deallocation op.
512 if (!pos)
513 return failure();
514 auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
515 if (!effectInterface)
516 continue;
517 if (effectInterface.hasEffect<MemoryEffects::Free>())
518 return failure();
519 }
520
521 if (source.getType() != cloneOp.getType())
522 source = memref::CastOp::create(rewriter, cloneOp.getLoc(),
523 cloneOp.getType(), source);
524 rewriter.replaceOp(cloneOp, source);
525 rewriter.eraseOp(redundantDealloc);
526 return success();
527 }
528};
529
530} // namespace
531
532void CloneOp::getCanonicalizationPatterns(RewritePatternSet &results,
533 MLIRContext *context) {
534 results.add<SimplifyClones>(context);
535}
536
537//===----------------------------------------------------------------------===//
538// DeallocTensorOp
539//===----------------------------------------------------------------------===//
540
541LogicalResult DeallocTensorOp::bufferize(RewriterBase &rewriter,
542 const BufferizationOptions &options,
543 BufferizationState &state) {
544 FailureOr<Value> buffer = getBuffer(rewriter, getTensor(), options, state);
545 if (failed(buffer))
546 return failure();
547 memref::DeallocOp::create(rewriter, getLoc(), *buffer);
548 rewriter.eraseOp(getOperation());
549 return success();
550}
551
552//===----------------------------------------------------------------------===//
553// MaterializeInDestinationOp
554//===----------------------------------------------------------------------===//
555
556bool MaterializeInDestinationOp::bufferizesToMemoryRead(
557 OpOperand &opOperand, const AnalysisState &state) {
558 return opOperand == getSourceMutable();
559}
560
561bool MaterializeInDestinationOp::bufferizesToMemoryWrite(
562 OpOperand &opOperand, const AnalysisState &state) {
563 if (opOperand == getDestMutable()) {
564 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
565 return true;
566 }
567 return false;
568}
569
570bool MaterializeInDestinationOp::mustBufferizeInPlace(
571 OpOperand &opOperand, const AnalysisState &state) {
572 // The source is only read and not written, so it always bufferizes in-place
573 // by default. The destination is written and is forced to bufferize in-place
574 // (if it is a tensor).
575 return true;
576}
577
578AliasingValueList
579MaterializeInDestinationOp::getAliasingValues(OpOperand &opOperand,
580 const AnalysisState &state) {
581 if (opOperand == getDestMutable()) {
582 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
583 return {{getOperation()->getResult(0), BufferRelation::Equivalent}};
584 }
585 return {};
586}
587
588LogicalResult
589MaterializeInDestinationOp::bufferize(RewriterBase &rewriter,
590 const BufferizationOptions &options,
591 BufferizationState &state) {
592 bool tensorDest = isa<TensorType>(getDest().getType());
593 Value buffer;
594 if (tensorDest) {
595 FailureOr<Value> maybeBuffer =
596 getBuffer(rewriter, getDest(), options, state);
597 if (failed(maybeBuffer))
598 return failure();
599 buffer = *maybeBuffer;
600 } else {
601 assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
602 buffer = getDest();
603 }
604 auto srcBuffer = getBuffer(rewriter, getSource(), options, state);
605 if (failed(srcBuffer))
606 return failure();
607 if (failed(options.createMemCpy(rewriter, getLoc(), *srcBuffer, buffer)))
608 return failure();
609 replaceOpWithBufferizedValues(rewriter, getOperation(),
610 tensorDest ? ValueRange(buffer) : ValueRange());
611 return success();
612}
613
614bool MaterializeInDestinationOp::bufferizesToElementwiseAccess(
615 const AnalysisState &state, ArrayRef<OpOperand *> opOperands) {
616 // As elements are copied from the "source" buffer to the "dest" buffer,
617 // already copied elements are not read a second time.
618 return true;
619}
620
621LogicalResult MaterializeInDestinationOp::reifyResultShapes(
622 OpBuilder &builder, ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
623 if (getOperation()->getNumResults() == 1) {
624 assert(isa<TensorType>(getDest().getType()) && "expected tensor type");
625 reifiedReturnShapes.resize(1,
626 SmallVector<OpFoldResult>(getType().getRank()));
627 reifiedReturnShapes[0] =
628 tensor::getMixedSizes(builder, getLoc(), getDest());
629 }
630 return success();
631}
632
633Value MaterializeInDestinationOp::buildSubsetExtraction(OpBuilder &builder,
634 Location loc) {
635 if (isa<TensorType>(getDest().getType())) {
636 // The subset is the entire destination tensor.
637 return getDest();
638 }
639
640 // The "restrict" attribute is transferred from this op to the newly created
641 // to_tensor op. If this op does not have the "restrict" attribute, the subset
642 // extraction cannot be built because there is no guarantee that there is no
643 // pre-existing "restrict" to_tensor op with the same/an aliasing destination.
644 if (!getRestrict())
645 return {};
646
647 // Build a bufferization.to_tensor op.
648 assert(isa<BaseMemRefType>(getDest().getType()) && "expected memref type");
649 assert(getRestrict() &&
650 "expected that ops with memrefs dest have 'restrict'");
651 setRestrict(false);
652 return ToTensorOp::create(
653 builder, loc, memref::getTensorTypeFromMemRefType(getDest().getType()),
654 getDest(),
655 /*restrict=*/true, getWritable());
656}
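// Editorial sketch (not part of the original source): for a memref
// destination carrying `restrict`, the subset extraction built above is
// roughly
//
//   %t = bufferization.to_tensor %dest restrict writable
//       : memref<?xf32> to tensor<?xf32>
//
// with the `restrict` flag moved from this op onto the new to_tensor op.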
657
658bool MaterializeInDestinationOp::isEquivalentSubset(
659 Value candidate, function_ref<bool(Value, Value)> equivalenceFn) {
660 return equivalenceFn(getDest(), candidate);
661}
662
663SmallVector<Value>
664MaterializeInDestinationOp::getValuesNeededToBuildSubsetExtraction() {
665 return {getDest()};
666}
667
668OpOperand &MaterializeInDestinationOp::getSourceOperand() {
669 return getOperation()->getOpOperand(0) /*source*/;
670}
671
672bool MaterializeInDestinationOp::operatesOnEquivalentSubset(
673 SubsetOpInterface subsetOp,
674 function_ref<bool(Value, Value)> equivalenceFn) {
675 return false;
676}
677
678bool MaterializeInDestinationOp::operatesOnDisjointSubset(
679 SubsetOpInterface subsetOp,
680 function_ref<bool(Value, Value)> equivalenceFn) {
681 return false;
682}
683
684LogicalResult MaterializeInDestinationOp::verify() {
685 if (!isa<TensorType, BaseMemRefType>(getDest().getType()))
686 return emitOpError("'dest' must be a tensor or a memref");
687 if (auto destType = dyn_cast<TensorType>(getDest().getType())) {
688 if (getOperation()->getNumResults() != 1)
689 return emitOpError("tensor 'dest' implies exactly one tensor result");
690 if (destType != getResult().getType())
691 return emitOpError("result and 'dest' types must match");
692 }
693 if (isa<BaseMemRefType>(getDest().getType()) &&
694 getOperation()->getNumResults() != 0)
695 return emitOpError("memref 'dest' implies zero results");
696 if (getRestrict() && !isa<BaseMemRefType>(getDest().getType()))
697 return emitOpError("'restrict' is valid only for memref destinations");
698 if (getWritable() != isa<BaseMemRefType>(getDest().getType()))
699 return emitOpError("'writable' must be specified if and only if the "
700 "destination is of memref type");
701 TensorType srcType = getSource().getType();
702 ShapedType destType = cast<ShapedType>(getDest().getType());
703 if (srcType.hasRank() != destType.hasRank())
704 return emitOpError("source/destination shapes are incompatible");
705 if (srcType.hasRank()) {
706 if (failed(verifyRanksMatch(getOperation(), srcType, destType, "source",
707 "destination")))
708 return failure();
709 for (auto [src, dest] :
710 llvm::zip(srcType.getShape(), destType.getShape())) {
711 if (src == ShapedType::kDynamic || dest == ShapedType::kDynamic) {
712 // Cannot verify dynamic dimension size. Assume that they match at
713 // runtime.
714 continue;
715 }
716 if (src != dest)
717 return emitOpError("source/destination shapes are incompatible");
718 }
719 }
720 return success();
721}
722
723void MaterializeInDestinationOp::build(OpBuilder &builder,
724 OperationState &state, Value source,
725 Value dest) {
726 auto destTensorType = dyn_cast<TensorType>(dest.getType());
727 build(builder, state, /*result=*/destTensorType ? destTensorType : Type(),
728 source, dest);
729}
730
731bool MaterializeInDestinationOp::isWritable(Value value,
732 const AnalysisState &state) {
733 return isa<TensorType>(getDest().getType()) ? true : getWritable();
734}
735
736MutableOperandRange MaterializeInDestinationOp::getDpsInitsMutable() {
737 return getDestMutable();
738}
739
740void MaterializeInDestinationOp::getEffects(
741 SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
742 &effects) {
743 if (isa<BaseMemRefType>(getDest().getType()))
744 effects.emplace_back(MemoryEffects::Write::get(), &getDestMutable(),
745 SideEffects::DefaultResource::get());
746}
747
748//===----------------------------------------------------------------------===//
749// ToTensorOp
750//===----------------------------------------------------------------------===//
751
752bool ToTensorOp::isWritable(Value value, const AnalysisState &state) {
753 return getWritable();
754}
755
756OpFoldResult ToTensorOp::fold(FoldAdaptor) {
757 if (auto toBuffer = getBuffer().getDefiningOp<ToBufferOp>())
758 // Approximate alias analysis by conservatively folding only when there
759 // is no interleaved operation.
760 if (toBuffer->getBlock() == this->getOperation()->getBlock() &&
761 toBuffer->getNextNode() == this->getOperation())
762 return toBuffer.getTensor();
763 return {};
764}
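// Editorial sketch (not part of the original source): the fold only fires
// when the to_buffer op immediately precedes the to_tensor op in the same
// block, e.g.
//
//   %m = bufferization.to_buffer %t : tensor<?xf32> to memref<?xf32>
//   %u = bufferization.to_tensor %m : memref<?xf32> to tensor<?xf32>
//
// where %u folds to %t; any interleaved operation blocks the fold.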
765
766namespace {
767struct DimOfToTensorFolder : public OpRewritePattern<tensor::DimOp> {
768 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
769
770 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
771 PatternRewriter &rewriter) const override {
772 auto memrefToTensorOp = dimOp.getSource().getDefiningOp<ToTensorOp>();
773 if (!memrefToTensorOp)
774 return failure();
775
776 rewriter.replaceOpWithNewOp<memref::DimOp>(
777 dimOp, memrefToTensorOp.getBuffer(), dimOp.getIndex());
778 return success();
779 }
780};
781} // namespace
782
783void ToTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
784 MLIRContext *context) {
785 results.add<DimOfToTensorFolder>(context);
786}
787
788//===----------------------------------------------------------------------===//
789// ToBufferOp
790//===----------------------------------------------------------------------===//
791
792OpFoldResult ToBufferOp::fold(FoldAdaptor) {
793 if (auto memrefToTensor = getTensor().getDefiningOp<ToTensorOp>())
794 if (memrefToTensor.getBuffer().getType() == getType())
795 return memrefToTensor.getBuffer();
796 return {};
797}
798
799namespace {
800
801/// Replace tensor.cast + to_buffer by to_buffer + memref.cast.
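/// Editorial sketch (not part of the original source):
///
///   %c = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %m = bufferization.to_buffer %c : tensor<?xf32> to memref<?xf32>
///
/// is rewritten to roughly
///
///   %b = bufferization.to_buffer %t : tensor<4xf32> to memref<4xf32>
///   %m = memref.cast %b : memref<4xf32> to memref<?xf32>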
802struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
803 using OpRewritePattern<ToBufferOp>::OpRewritePattern;
804
805 LogicalResult matchAndRewrite(ToBufferOp toBuffer,
806 PatternRewriter &rewriter) const final {
807 auto tensorCastOperand =
808 toBuffer.getOperand().getDefiningOp<tensor::CastOp>();
809 if (!tensorCastOperand)
810 return failure();
811 auto srcTensorType = llvm::dyn_cast<RankedTensorType>(
812 tensorCastOperand.getOperand().getType());
813 if (!srcTensorType)
814 return failure();
815 auto currentOutputMemRefType =
816 dyn_cast<BaseMemRefType>(toBuffer.getResult().getType());
817 if (!currentOutputMemRefType)
818 return failure();
819
820 auto memrefType = currentOutputMemRefType.cloneWith(
821 srcTensorType.getShape(), srcTensorType.getElementType());
822 Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
823 tensorCastOperand.getOperand(),
824 toBuffer.getReadOnly());
825 rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
826 memref);
827 return success();
828 }
829};
830
831/// Canonicalize bufferization.to_tensor + bufferization.to_buffer. Insert a
832/// cast if necessary.
833struct ToBufferToTensorFolding : public OpRewritePattern<ToBufferOp> {
834 using OpRewritePattern<ToBufferOp>::OpRewritePattern;
835
836 LogicalResult matchAndRewrite(ToBufferOp toBuffer,
837 PatternRewriter &rewriter) const final {
838 BufferizationOptions options;
839 options.bufferAlignment = 0;
840 return foldToBufferToTensorPair(rewriter, toBuffer, options);
841 }
842};
843
844/// Fold a load on a to_buffer operation into a tensor.extract on the
845/// corresponding tensor.
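/// Editorial sketch (not part of the original source): this only applies to
/// read-only to_buffer ops, e.g.
///
///   %m = bufferization.to_buffer %t read_only : tensor<?xf32> to memref<?xf32>
///   %v = memref.load %m[%i] : memref<?xf32>
///
/// becomes
///
///   %v = tensor.extract %t[%i] : tensor<?xf32>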
846struct LoadOfToBuffer : public OpRewritePattern<memref::LoadOp> {
847 using OpRewritePattern<memref::LoadOp>::OpRewritePattern;
848
849 LogicalResult matchAndRewrite(memref::LoadOp load,
850 PatternRewriter &rewriter) const override {
851 auto toBuffer = load.getMemref().getDefiningOp<ToBufferOp>();
852 if (!toBuffer || !toBuffer.getReadOnly())
853 return failure();
854
855 rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, toBuffer.getTensor(),
856 load.getIndices());
857 return success();
858 }
859};
860
861/// Fold dim of a to_buffer into the dim of the tensor.
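/// Editorial sketch (not part of the original source):
///
///   %m = bufferization.to_buffer %t : tensor<?xf32> to memref<?xf32>
///   %d = memref.dim %m, %c0 : memref<?xf32>
///
/// becomes
///
///   %d = tensor.dim %t, %c0 : tensor<?xf32>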
862struct DimOfCastOp : public OpRewritePattern<memref::DimOp> {
863 using OpRewritePattern<memref::DimOp>::OpRewritePattern;
864
865 LogicalResult matchAndRewrite(memref::DimOp dimOp,
866 PatternRewriter &rewriter) const override {
867 auto castOp = dimOp.getSource().getDefiningOp<ToBufferOp>();
868 if (!castOp)
869 return failure();
870 Value newSource = castOp.getOperand();
871 rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource,
872 dimOp.getIndex());
873 return success();
874 }
875};
876
877} // namespace
878
879void ToBufferOp::getCanonicalizationPatterns(RewritePatternSet &results,
880 MLIRContext *context) {
881 results.add<DimOfCastOp, LoadOfToBuffer, ToBufferOfCast,
882 ToBufferToTensorFolding>(context);
883}
884
885LogicalResult ToBufferOp::bufferize(RewriterBase &rewriter,
886 const BufferizationOptions &options,
887 BufferizationState &state) {
888 // Fold to_buffer(to_tensor(x)) to x. Insert a cast if necessary.
889 (void)foldToBufferToTensorPair(rewriter, *this, options);
890 // Note: The return value of `bufferize` indicates whether there was an error
891 // or not. (And not whether the pattern matched or not.)
892 return success();
893}
894
895std::optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder,
896 Value alloc) {
897 return memref::DeallocOp::create(builder, alloc.getLoc(), alloc)
898 .getOperation();
899}
900
901std::optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
902 return CloneOp::create(builder, alloc.getLoc(), alloc).getResult();
903}
904
905//===----------------------------------------------------------------------===//
906// DeallocOp
907//===----------------------------------------------------------------------===//
908
909LogicalResult DeallocOp::inferReturnTypes(
910 MLIRContext *context, std::optional<::mlir::Location> location,
911 ValueRange operands, DictionaryAttr attributes, OpaqueProperties properties,
912 RegionRange regions, SmallVectorImpl<Type> &inferredReturnTypes) {
913 DeallocOpAdaptor adaptor(operands, attributes, properties, regions);
914 inferredReturnTypes = SmallVector<Type>(adaptor.getRetained().size(),
915 IntegerType::get(context, 1));
916 return success();
917}
918
919LogicalResult DeallocOp::verify() {
920 if (getMemrefs().size() != getConditions().size())
921 return emitOpError(
922 "must have the same number of conditions as memrefs to deallocate");
923 if (getRetained().size() != getUpdatedConditions().size())
924 return emitOpError("must have the same number of updated conditions "
925 "(results) as retained operands");
926 return success();
927}
928
929static LogicalResult updateDeallocIfChanged(DeallocOp deallocOp,
930 ValueRange memrefs,
931 ValueRange conditions,
932 PatternRewriter &rewriter) {
933 if (deallocOp.getMemrefs() == memrefs &&
934 deallocOp.getConditions() == conditions)
935 return failure();
936
937 rewriter.modifyOpInPlace(deallocOp, [&]() {
938 deallocOp.getMemrefsMutable().assign(memrefs);
939 deallocOp.getConditionsMutable().assign(conditions);
940 });
941 return success();
942}
943
944namespace {
945
946/// Remove duplicate values in the list of memrefs to be deallocated. We need to
947/// make sure the corresponding condition value is updated accordingly since
948/// the two conditions might not cover the same set of cases. In that case, we
949/// have to combine them by computing their disjunction.
950/// Example:
951/// ```mlir
952/// bufferization.dealloc (%arg0, %arg0 : ...) if (%arg1, %arg2)
953/// ```
954/// is canonicalized to
955/// ```mlir
956/// %0 = arith.ori %arg1, %arg2 : i1
957/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%0)
958/// ```
959struct DeallocRemoveDuplicateDeallocMemrefs
960 : public OpRewritePattern<DeallocOp> {
961 using OpRewritePattern<DeallocOp>::OpRewritePattern;
962
963 LogicalResult matchAndRewrite(DeallocOp deallocOp,
964 PatternRewriter &rewriter) const override {
965 // Unique memrefs to be deallocated.
966 DenseMap<Value, unsigned> memrefToCondition;
967 SmallVector<Value> newMemrefs, newConditions;
968 for (auto [i, memref, cond] :
969 llvm::enumerate(deallocOp.getMemrefs(), deallocOp.getConditions())) {
970 if (memrefToCondition.count(memref)) {
971 // If the dealloc conditions don't match, we need to make sure that the
972 // dealloc happens on the union of cases.
973 Value &newCond = newConditions[memrefToCondition[memref]];
974 if (newCond != cond)
975 newCond =
976 arith::OrIOp::create(rewriter, deallocOp.getLoc(), newCond, cond);
977 } else {
978 memrefToCondition.insert({memref, newConditions.size()});
979 newMemrefs.push_back(memref);
980 newConditions.push_back(cond);
981 }
982 }
983
984 // Return failure if we don't change anything, so that we don't run into an
985 // infinite loop of pattern applications.
986 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
987 rewriter);
988 }
989};
990
991/// Remove duplicate values in the list of retained memrefs. We need to make
992/// sure the corresponding result condition value is replaced properly.
993/// Example:
994/// ```mlir
995/// %0:2 = bufferization.dealloc retain (%arg3, %arg3 : ...)
996/// ```
997/// is canonicalized to
998/// ```mlir
999/// %0 = bufferization.dealloc retain (%arg3 : memref<2xi32>)
1000/// ```
1001struct DeallocRemoveDuplicateRetainedMemrefs
1002 : public OpRewritePattern<DeallocOp> {
1003 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1004
1005 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1006 PatternRewriter &rewriter) const override {
1007 // Unique retained values
1008 DenseMap<Value, unsigned> seen;
1009 SmallVector<Value> newRetained;
1010 SmallVector<unsigned> resultReplacementIdx;
1011 unsigned i = 0;
1012 for (auto retained : deallocOp.getRetained()) {
1013 if (seen.count(retained)) {
1014 resultReplacementIdx.push_back(seen[retained]);
1015 continue;
1016 }
1017
1018 seen[retained] = i;
1019 newRetained.push_back(retained);
1020 resultReplacementIdx.push_back(i++);
1021 }
1022
1023 // Return failure if we don't change anything, so that we don't run into an
1024 // infinite loop of pattern applications.
1025 if (newRetained.size() == deallocOp.getRetained().size())
1026 return failure();
1027
1028 // We need to create a new op because the number of results is always the
1029 // same as the number of condition operands.
1030 auto newDeallocOp =
1031 DeallocOp::create(rewriter, deallocOp.getLoc(), deallocOp.getMemrefs(),
1032 deallocOp.getConditions(), newRetained);
1033 SmallVector<Value> replacements(
1034 llvm::map_range(resultReplacementIdx, [&](unsigned idx) {
1035 return newDeallocOp.getUpdatedConditions()[idx];
1036 }));
1037 rewriter.replaceOp(deallocOp, replacements);
1038 return success();
1039 }
1040};
1041
1042/// Erase deallocation operations where the variadic list of memrefs to
1043/// deallocate is empty; its results are then replaced with constant `false`. Example:
1044/// ```mlir
1045/// %0 = bufferization.dealloc retain (%arg0: memref<2xi32>)
1046/// ```
1047struct EraseEmptyDealloc : public OpRewritePattern<DeallocOp> {
1048 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1049
1050 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1051 PatternRewriter &rewriter) const override {
1052 if (deallocOp.getMemrefs().empty()) {
1053 Value constFalse = arith::ConstantOp::create(rewriter, deallocOp.getLoc(),
1054 rewriter.getBoolAttr(false));
1055 rewriter.replaceOp(
1056 deallocOp, SmallVector<Value>(deallocOp.getUpdatedConditions().size(),
1057 constFalse));
1058 return success();
1059 }
1060 return failure();
1061 }
1062};
1063
1064/// Removes memrefs from the deallocation list if their associated condition is
1065/// always 'false'.
1066///
1067/// Example:
1068/// ```
1069/// bufferization.dealloc (%arg0, %arg1 : memref<2xi32>, memref<2xi32>)
1070/// if (%arg2, %false)
1071/// ```
1072/// becomes
1073/// ```
1074/// bufferization.dealloc (%arg0 : memref<2xi32>) if (%arg2)
1075/// ```
1076struct EraseAlwaysFalseDealloc : public OpRewritePattern<DeallocOp> {
1077 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1078
1079 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1080 PatternRewriter &rewriter) const override {
1081 SmallVector<Value> newMemrefs, newConditions;
1082 for (auto [memref, cond] :
1083 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1084 if (!matchPattern(cond, m_Zero())) {
1085 newMemrefs.push_back(memref);
1086 newConditions.push_back(cond);
1087 }
1088 }
1089
1090 return updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1091 rewriter);
1092 }
1093};
1094
1095/// The `memref.extract_strided_metadata` op is often inserted to get the base
1096/// memref if the operand is not already guaranteed to be the result of a memref
1097/// allocation operation. This canonicalization pattern removes this extraction
1098/// operation if the operand is now produced by an allocation operation (e.g.,
1099/// due to other canonicalizations simplifying the IR).
1100///
1101/// Example:
1102/// ```mlir
1103/// %alloc = memref.alloc() : memref<2xi32>
1104/// %base_memref, %offset, %size, %stride = memref.extract_strided_metadata
1105/// %alloc : memref<2xi32> -> memref<i32>, index, index, index
1106/// bufferization.dealloc (%base_memref : memref<i32>) if (%cond)
1107/// ```
1108/// is canonicalized to
1109/// ```mlir
1110/// %alloc = memref.alloc() : memref<2xi32>
1111/// bufferization.dealloc (%alloc : memref<2xi32>) if (%cond)
1112/// ```
1113struct SkipExtractMetadataOfAlloc : public OpRewritePattern<DeallocOp> {
1114 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1115
1116 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1117 PatternRewriter &rewriter) const override {
1118 SmallVector<Value> newMemrefs(
1119 llvm::map_range(deallocOp.getMemrefs(), [&](Value memref) {
1120 auto extractStridedOp =
1121 memref.getDefiningOp<memref::ExtractStridedMetadataOp>();
1122 if (!extractStridedOp)
1123 return memref;
1124 Value allocMemref = extractStridedOp.getOperand();
1125 auto allocOp = allocMemref.getDefiningOp<MemoryEffectOpInterface>();
1126 if (!allocOp)
1127 return memref;
1128 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(allocMemref))
1129 return allocMemref;
1130 return memref;
1131 }));
1132
1133 return updateDeallocIfChanged(deallocOp, newMemrefs,
1134 deallocOp.getConditions(), rewriter);
1135 }
1136};
1137
1138/// Removes pairs of `bufferization.dealloc` and alloc operations if there is no
1139/// other user of the allocated value and the allocating operation can be safely
1140/// removed. If the same value is present multiple times, this pattern relies on
1141/// other canonicalization patterns to remove the duplicate first.
1142///
1143/// Example:
1144/// ```mlir
1145/// %alloc = memref.alloc() : memref<2xi32>
1146/// bufferization.dealloc (%alloc, %arg0, : ...) if (%true, %true)
1147/// ```
1148/// is canonicalized to
1149/// ```mlir
1150/// bufferization.dealloc (%arg0 : ...) if (%true)
1151/// ```
1152struct RemoveAllocDeallocPairWhenNoOtherUsers
1153 : public OpRewritePattern<DeallocOp> {
1154 using OpRewritePattern<DeallocOp>::OpRewritePattern;
1155
1156 LogicalResult matchAndRewrite(DeallocOp deallocOp,
1157 PatternRewriter &rewriter) const override {
1158 SmallVector<Value> newMemrefs, newConditions;
1159 SmallVector<Operation *> toDelete;
1160 for (auto [memref, cond] :
1161 llvm::zip(deallocOp.getMemrefs(), deallocOp.getConditions())) {
1162 if (auto allocOp = memref.getDefiningOp<MemoryEffectOpInterface>()) {
1163 // Check that it is indeed an allocate effect, that the op has no other
1164 // side effects (which would not allow us to remove the op), and that
1165 // there are no other users.
1166 if (allocOp.getEffectOnValue<MemoryEffects::Allocate>(memref) &&
1167 hasSingleEffect<MemoryEffects::Allocate>(allocOp) &&
1168 memref.hasOneUse()) {
1169 toDelete.push_back(allocOp);
1170 continue;
1171 }
1172 }
1173
1174 newMemrefs.push_back(memref);
1175 newConditions.push_back(cond);
1176 }
1177
1178 if (failed(updateDeallocIfChanged(deallocOp, newMemrefs, newConditions,
1179 rewriter)))
1180 return failure();
1181
1182 for (Operation *op : toDelete)
1183 rewriter.eraseOp(op);
1184
1185 return success();
1186 }
1187};
1188
1189} // anonymous namespace
1190
1191void DeallocOp::getCanonicalizationPatterns(RewritePatternSet &results,
1192 MLIRContext *context) {
1193 populateDeallocOpCanonicalizationPatterns(results, context);
1194}
1195
1196void bufferization::populateDeallocOpCanonicalizationPatterns(
1197 RewritePatternSet &patterns, MLIRContext *context) {
1198 patterns.add<DeallocRemoveDuplicateDeallocMemrefs,
1199 DeallocRemoveDuplicateRetainedMemrefs, EraseEmptyDealloc,
1200 EraseAlwaysFalseDealloc, SkipExtractMetadataOfAlloc,
1201 RemoveAllocDeallocPairWhenNoOtherUsers>(context);
1202}
1203
1204//===----------------------------------------------------------------------===//
1205// TableGen'd op method definitions
1206//===----------------------------------------------------------------------===//
1207
1208#define GET_OP_CLASSES
1209#include "mlir/Dialect/Bufferization/IR/BufferizationOps.cpp.inc"