25 #include "llvm/ADT/STLExtras.h"
35 OperandRange::iterator &elementIt,
37 if (dim ==
static_cast<int>(shape.size()) - 1) {
38 for (
int i = 0; i < shape.back(); ++i) {
39 indices.back() = constants[i];
40 destination = tensor::InsertOp::create(rewriter, loc, *elementIt,
41 destination, indices);
46 for (
int i = 0; i < shape[dim]; ++i) {
47 indices[dim] = constants[i];
48 destination =
createInserts(rewriter, loc, dim + 1, destination, shape,
49 constants, elementIt, indices);
59 auto tensorType = dyn_cast<RankedTensorType>(tensorSource.
getType());
60 assert(tensorType &&
"expected ranked tensor");
61 assert(isa<MemRefType>(memrefDest.
getType()) &&
"expected ranked memref");
68 auto materializeOp = bufferization::MaterializeInDestinationOp::create(
69 b, loc, tensorSource, memrefDest);
70 materializeOp.setWritable(
true);
76 Value toBuffer = bufferization::ToBufferOp::create(
79 memref::CopyOp::create(b, loc, toBuffer, memrefDest);
85 Value toBuffer = bufferization::ToBufferOp::create(
88 linalg::CopyOp::create(b, loc, toBuffer, memrefDest);
97 RankedTensorType resultType = padOp.getResultType();
102 cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
106 isa<BlockArgument>(yieldedValue) &&
107 cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() !=
108 padOp.getOperation();
110 bool outsideOpResult =
111 isa<OpResult>(yieldedValue) &&
113 bool invariantYieldedValue = outsideBbArg || outsideOpResult;
123 auto fillOp = linalg::FillOp::create(rewriter, loc,
ValueRange(fillValue),
128 if (invariantYieldedValue) {
130 auto fillOp = linalg::FillOp::create(
137 utils::IteratorType::parallel);
140 auto genericOp = linalg::GenericOp::create(
143 indexingMaps, iteratorTypes);
145 resultType.getElementType(), loc);
148 for (int64_t i = 0; i < resultType.getRank(); ++i)
149 bbArgReplacements.push_back(linalg::IndexOp::create(rewriter, loc, i));
150 rewriter.
mergeBlocks(padOp.getBody(), body, bbArgReplacements);
153 auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
160 auto tensorType = cast<RankedTensorType>(value.
getType());
161 if (tensorType.hasStaticShape())
166 if (isa<OpResult>(value) &&
169 for (int64_t i = 0; i < tensorType.getRank(); ++i) {
170 if (tensorType.isDynamicDim(i))
171 dynSizes.push_back(cast<Value>(
172 reifiedShape[cast<OpResult>(value).getResultNumber()][i]));
179 for (int64_t i = 0; i < tensorType.getRank(); ++i) {
180 if (tensorType.isDynamicDim(i))
182 DimOp::create(b, value.
getLoc(), value,
193 auto tensorType = cast<RankedTensorType>(value.
getType());
198 tensorType, memorySpace));
204 alloc = memref::AllocOp::create(rewriter, loc, memrefType, dynamicSizes);
208 memref::DeallocOp::create(rewriter, loc, alloc);
212 alloc = memref::AllocaOp::create(rewriter, loc, memrefType, dynamicSizes);
223 assert(!
options.bufferizeDestinationOnly &&
"invalid options");
234 if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {
246 Value subview = memref::SubViewOp::create(
247 rewriter, loc, alloc, padOp.getMixedLowPad(), sizes, strides);
252 Value toTensorOp = bufferization::ToTensorOp::create(
253 rewriter, loc, padOp.getResult().getType(), alloc,
true,
262 assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
263 "expected single masked op");
270 Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
271 assert(isa<vector::YieldOp>(yieldOp) &&
"expected yield op terminator");
276 rewriter,
options, maskOp.getMaskableOp(), memorySpace,
277 insertionPoint ? insertionPoint : maskOp);
279 if (
options.bufferizeDestinationOnly)
284 if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
285 rewriter, bufferizationOptions, bufferizationState)))
292 maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {
293 if (toTensorOp->getUses().empty())
294 toTensorOps.push_back(toTensorOp.getOperation());
301 for (
Value result : maskOp.getResults())
302 if (isa<TensorType>(result.getType()))
304 resultUses.push_back(&use);
307 cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
308 .bufferize(rewriter, bufferizationOptions, bufferizationState)))
313 for (
OpOperand *resultUse : resultUses) {
315 resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
316 assert(toTensorOp &&
"expected to_tensor op");
318 toTensorOp.setRestrict(
true);
319 toTensorOp.setWritable(
true);
328 bufferization::AllocTensorOp allocTensorOp,
Attribute memorySpace,
330 Location loc = allocTensorOp.getLoc();
337 rewriter, loc, allocTensorOp.getResult(),
options, memorySpace);
341 Value toTensorOp = bufferization::ToTensorOp::create(
342 rewriter, loc, allocTensorOp.getResult().getType(), alloc,
345 rewriter.
replaceOp(allocTensorOp, toTensorOp);
351 RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
352 Location loc = fromElementsOp.getLoc();
353 RankedTensorType tensorType =
354 cast<RankedTensorType>(fromElementsOp.getType());
355 auto shape = tensorType.getShape();
358 auto emptyOp = EmptyOp::create(rewriter, loc, tensorType,
ValueRange());
363 fromElementsOp, fromElementsOp.getElements().front(),
369 auto maxDim = *llvm::max_element(shape);
371 constants.reserve(maxDim);
372 for (
int i = 0; i < maxDim; ++i)
376 auto elementIt = fromElementsOp.getElements().begin();
379 shape, constants, elementIt, indices);
382 rewriter.
replaceOp(fromElementsOp, result);
387 FailureOr<Operation *>
389 tensor::GenerateOp generateOp) {
391 if (!generateOp.getBody().hasOneBlock())
395 RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());
398 auto emptyOp = EmptyOp::create(rewriter, loc, tensorType,
399 generateOp.getDynamicExtents());
403 utils::IteratorType::parallel);
406 auto genericOp = linalg::GenericOp::create(
409 indexingMaps, iteratorTypes);
411 tensorType.getElementType(), loc);
414 for (int64_t i = 0; i < tensorType.getRank(); ++i)
415 bbArgReplacements.push_back(linalg::IndexOp::create(rewriter, loc, i));
416 rewriter.
mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);
419 auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
423 rewriter.
replaceOp(generateOp, genericOp->getResult(0));
424 return genericOp.getOperation();
428 FailureOr<Operation *>
430 tensor::PadOp padOp) {
432 if (!padOp.getBodyRegion().hasOneBlock())
437 RankedTensorType resultType = padOp.getResultType();
441 padOp,
"failed to reify tensor.pad op result shape");
443 for (int64_t i = 0; i < resultType.getRank(); ++i)
444 if (resultType.isDynamicDim(i))
445 dynamicSizes.push_back(cast<Value>(reifiedShape[0][i]));
449 if (padOp.getNofoldAttr() &&
452 using bufferization::AllocTensorOp;
454 AllocTensorOp::create(rewriter, loc, resultType, dynamicSizes);
456 padOp, padOp.getSource(), allocated);
457 return copyOp.getOperation();
460 Value empty = EmptyOp::create(rewriter, loc, resultType, dynamicSizes);
471 padOp, padOp.getSource(), fillOp->
getResult(0),
472 padOp.getMixedLowPad(), sliceSizes, sliceStrides);
473 return insertSliceOp.getOperation();
479 using namespace bufferization;
482 if (
auto padOp = dyn_cast<tensor::PadOp>(op))
484 if (
auto maskOp = dyn_cast<vector::MaskOp>(op))
486 if (
auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
490 auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
495 BufferizationOptions bufferizationOptions;
497 BufferizationState bufferizationState;
500 if (!
options.bufferizeDestinationOnly) {
507 [](
Value v) { return isa<TensorType>(v.getType()); }))
508 llvm_unreachable(
"ops with nested tensor ops are not supported yet");
510 [](
Value v) { return isa<TensorType>(v.getType()); }))
511 llvm_unreachable(
"ops with nested tensor ops are not supported yet");
519 if (!isa<TensorType>(result.getType()))
522 if (!isa<RankedTensorType>(result.getType()))
525 if (bufferizableOp.bufferizesToAllocation(result))
527 tensorResults.push_back(result);
533 auto addOutOfPlaceOperand = [&](
OpOperand *operand) {
534 if (!llvm::is_contained(outOfPlaceOperands, operand))
535 outOfPlaceOperands.push_back(operand);
537 for (
OpResult result : tensorResults) {
539 analysisState.getAliasingOpOperands(result);
540 for (
const AliasingOpOperand &operand : aliasingOperands) {
541 addOutOfPlaceOperand(operand.opOperand);
542 for (
OpOperand &resultUse : result.getUses())
543 resultUses.push_back(&resultUse);
547 if (!analysisState.bufferizesToMemoryWrite(operand))
549 if (!isa<RankedTensorType>(operand.get().getType()))
551 addOutOfPlaceOperand(&operand);
554 if (outOfPlaceOperands.size() != 1)
561 for (
OpOperand *operand : outOfPlaceOperands) {
563 rewriter, op->
getLoc(), operand->get(),
options, memorySpace);
564 allocs.push_back(alloc);
565 if (!analysisState.findDefinitions(operand).empty()) {
571 auto toTensorOp = ToTensorOp::create(rewriter, op->
getLoc(),
572 operand->get().getType(), alloc);
573 operand->set(toTensorOp);
574 if (
options.bufferizeDestinationOnly) {
576 toTensorOp.setRestrict(
true);
577 toTensorOp.setWritable(
true);
583 if (
options.bufferizeDestinationOnly)
584 return allocs.front();
588 if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions,
589 bufferizationState)))
594 for (
OpOperand *resultUse : resultUses) {
595 auto toTensorOp = resultUse->get().
getDefiningOp<ToTensorOp>();
596 assert(toTensorOp &&
"expected to_tensor op");
598 toTensorOp.setRestrict(
true);
599 toTensorOp.setWritable(
true);
602 return allocs.front();
607 template <
typename OpTy>
608 LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
617 patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
618 patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
619 patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
static Operation * movePaddingToFillOrGenericOp(RewriterBase &rewriter, Location loc, PadOp padOp, Value dest)
static Value createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value, const linalg::BufferizeToAllocationOptions &options, Attribute memorySpace={})
static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, Value memrefDest, const linalg::BufferizeToAllocationOptions &options)
Create a memcpy from the given source tensor to the given destination memref.
static SmallVector< Value > reifyOrComputeDynamicSizes(OpBuilder &b, Value value)
static Value createInserts(RewriterBase &rewriter, Location loc, int dim, Value destination, ArrayRef< int64_t > shape, ArrayRef< Value > constants, OperandRange::iterator &elementIt, SmallVectorImpl< Value > &indices)
static llvm::ManagedStatic< PassManagerOptions > options
Base class for generic analysis states.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
Operation * getTerminator()
Get the terminator operation of this block.
IntegerAttr getIndexAttr(int64_t value)
AffineMap getMultiDimIdentityMap(unsigned rank)
MLIRContext * getContext() const
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
virtual Operation * materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc)
Registered hook to materialize a single constant operation from a given attribute value with the desi...
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Dialect * getLoadedDialect(StringRef name)
Get a registered IR dialect with the given namespace.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Block * getInsertionBlock() const
Return the block the current insertion point belongs to.
This class represents an operand of an operation.
This is a value defined by a result of an operation.
Operation is the basic unit of execution within MLIR.
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
MutableArrayRef< OpOperand > getOpOperands()
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void mergeBlocks(Block *source, Block *dest, ValueRange argValues={})
Inline the operations of block 'source' into the end of block 'dest'.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Location getLoc() const
Return the location of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
BufferizationState provides information about the state of the IR during the bufferization process.
BaseMemRefType getMemRefTypeWithStaticIdentityLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with a static identity layout (i.e., no layout map).
AliasList< AliasingOpOperand > AliasingOpOperandList
A list of possible aliasing OpOperands.
BaseMemRefType getMemRefTypeWithFullyDynamicLayout(TensorType tensorType, Attribute memorySpace=nullptr)
Return a MemRef type with fully dynamic layout.
Value bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace={}, Operation *insertionPoint=nullptr)
Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg....
FailureOr< Operation * > rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
Rewrite tensor.from_elements to linalg.generic.
void populateConvertToDestinationStylePatterns(RewritePatternSet &patterns)
Populate patterns that convert non-destination-style ops to destination style ops.
SmallVector< OpFoldResult > getMixedSizes(OpBuilder &builder, Location loc, Value value)
Return the dimensions of the given tensor value.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult reifyResultShapes(OpBuilder &b, Operation *op, ReifiedRankedShapedTypeDims &reifiedReturnShapes)
Reify the shape of the result of an operation (typically in terms of the shape of its operands).
const FrozenRewritePatternSet & patterns
bool isZeroInteger(OpFoldResult v)
Return true if v is an IntegerAttr with value 0.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Options for BufferizableOpInterface-based bufferization.
@ MaterializeInDestination