#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Debug.h"
static Value createInserts(RewriterBase &rewriter, Location loc, int dim,
                           Value destination, ArrayRef<int64_t> shape,
                           ArrayRef<Value> constants,
                           OperandRange::iterator &elementIt,
                           SmallVectorImpl<Value> &indices) {
  // Innermost dimension: create one tensor.insert per element and advance
  // the element iterator.
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = constants[i];
      destination = rewriter.create<tensor::InsertOp>(loc, *elementIt,
                                                      destination, indices);
      ++elementIt;
    }
    return destination;
  }
  // Outer dimensions: fix the index for this dimension and recurse.
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = constants[i];
    destination = createInserts(rewriter, loc, dim + 1, destination, shape,
                                constants, elementIt, indices);
  }
  return destination;
}
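The recursion walks the shape dimension by dimension and emits one tensor.insert per element, advancing elementIt in row-major order. A standalone C++ sketch of the same traversal (a hypothetical helper, for illustration only) prints the index vectors in exactly the order the inserts would be created:

// Illustrative only; mirrors the recursion above on plain std::vector data.
#include <cstdio>
#include <vector>

static void visit(int dim, const std::vector<int> &shape,
                  std::vector<int> &indices) {
  if (dim == static_cast<int>(shape.size()) - 1) {
    for (int i = 0; i < shape.back(); ++i) {
      indices.back() = i;
      for (int idx : indices)
        std::printf("%d ", idx);
      std::printf("\n");
    }
    return;
  }
  for (int i = 0; i < shape[dim]; ++i) {
    indices[dim] = i;
    visit(dim + 1, shape, indices);
  }
}

int main() {
  std::vector<int> shape = {2, 3};
  std::vector<int> indices(shape.size(), 0);
  visit(0, shape, indices); // prints 0 0, 0 1, 0 2, 1 0, 1 1, 1 2
}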
static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource,
                         Value memrefDest,
                         const linalg::BufferizeToAllocationOptions &options) {
  auto tensorType = dyn_cast<RankedTensorType>(tensorSource.getType());
  assert(tensorType && "expected ranked tensor");
  assert(isa<MemRefType>(memrefDest.getType()) && "expected ranked memref");
  // ...
  // Strategy 1: lower to bufferization.materialize_in_destination.
  auto materializeOp = b.create<bufferization::MaterializeInDestinationOp>(
      loc, tensorSource, memrefDest);
  materializeOp.setWritable(true);
  // ...
  // Strategy 2: lower to memref.copy.
  Value toMemref = b.create<bufferization::ToMemrefOp>(
      loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
      tensorSource);
  b.create<memref::CopyOp>(loc, toMemref, memrefDest);
  // ...
  // Strategy 3: lower to linalg.copy.
  Value toMemref = b.create<bufferization::ToMemrefOp>(
      loc, bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType),
      tensorSource);
  b.create<linalg::CopyOp>(loc, toMemref, memrefDest);
  // ...
}
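Which of the three lowerings is emitted is selected through the options struct. A minimal configuration sketch; the MemcpyOp member names are assumptions inferred from this file's usage (the MaterializeInDestination enumerator appears in the cross-references below), not verified against the header:

// Hypothetical configuration sketch; member names assumed.
linalg::BufferizeToAllocationOptions options;
options.memcpyOp =
    linalg::BufferizeToAllocationOptions::MemcpyOp::MaterializeInDestination;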
static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
                                               Location loc, PadOp padOp,
                                               Value dest) {
  RankedTensorType resultType = padOp.getResultType();

  // Examine the yielded padding value: if it is defined outside of the pad
  // op's body, it is invariant and a linalg.fill suffices.
  Value yieldedValue =
      cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
  bool outsideBbArg =
      isa<BlockArgument>(yieldedValue) &&
      cast<BlockArgument>(yieldedValue).getOwner()->getParentOp() !=
          padOp.getOperation();
  bool outsideOpResult =
      isa<OpResult>(yieldedValue) &&
      yieldedValue.getDefiningOp()->getParentOp() != padOp.getOperation();
  bool invariantYieldedValue = outsideBbArg || outsideOpResult;
  // ...
  if (invariantYieldedValue) {
    // Padding with a constant: create linalg.fill.
    auto fillOp = rewriter.create<linalg::FillOp>(loc,
                                                  ValueRange(yieldedValue),
                                                  ValueRange(dest));
    return fillOp.getOperation();
  }

  // Padding with a non-constant value: create linalg.generic.
  SmallVector<utils::IteratorType> iteratorTypes(resultType.getRank(),
                                                 utils::IteratorType::parallel);
  SmallVector<AffineMap> indexingMaps(
      1, rewriter.getMultiDimIdentityMap(resultType.getRank()));
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, resultType, /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{dest},
      indexingMaps, iteratorTypes);
  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
                                     resultType.getElementType(), loc);
  rewriter.setInsertionPointToStart(body);

  // Replace the pad body's block arguments with linalg.index ops and merge
  // the body into the generic op.
  SmallVector<Value> bbArgReplacements;
  for (int64_t i = 0; i < resultType.getRank(); ++i)
    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
  rewriter.mergeBlocks(padOp.getBody(), body, bbArgReplacements);

  // Update the terminator: tensor.yield becomes linalg.yield.
  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
  // ...
}
static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b,
                                                     Value value) {
  auto tensorType = cast<RankedTensorType>(value.getType());
  if (tensorType.hasStaticShape())
    return {};

  // Try to reify the dynamic sizes from the defining op.
  ReifiedRankedShapedTypeDims reifiedShape;
  if (isa<OpResult>(value) &&
      succeeded(reifyResultShapes(b, value.getDefiningOp(), reifiedShape))) {
    SmallVector<Value> dynSizes;
    for (int64_t i = 0; i < tensorType.getRank(); ++i) {
      if (tensorType.isDynamicDim(i))
        dynSizes.push_back(cast<Value>(
            reifiedShape[cast<OpResult>(value).getResultNumber()][i]));
    }
    return dynSizes;
  }

  // Otherwise, create a tensor.dim op for each dynamic dimension.
  SmallVector<Value> dynSizes;
  for (int64_t i = 0; i < tensorType.getRank(); ++i) {
    if (tensorType.isDynamicDim(i))
      dynSizes.push_back(b.create<tensor::DimOp>(value.getLoc(), value, i));
  }
  return dynSizes;
}
static Value createAllocationForTensor(
    RewriterBase &rewriter, Location loc, Value value,
    const linalg::BufferizeToAllocationOptions &options,
    Attribute memorySpace = {}) {
  auto tensorType = cast<RankedTensorType>(value.getType());

  // Create the buffer allocation with a static identity layout.
  auto memrefType =
      cast<MemRefType>(bufferization::getMemRefTypeWithStaticIdentityLayout(
          tensorType, memorySpace));
  SmallVector<Value> dynamicSizes = reifyOrComputeDynamicSizes(rewriter, value);

  Value alloc;
  // Either allocate with memref.alloc (optionally paired with a
  // memref.dealloc at the end of the block) ...
  alloc = rewriter.create<memref::AllocOp>(loc, memrefType, dynamicSizes);
  // ...
  rewriter.create<memref::DeallocOp>(loc, alloc);
  // ... or with memref.alloca, which needs no explicit deallocation.
  alloc = rewriter.create<memref::AllocaOp>(loc, memrefType, dynamicSizes);
  // ...
  return alloc;
}
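The allocation op is selected through the same options struct as the copy lowering. A minimal sketch; the allocOp and emitDealloc member names are assumptions, not verified against the header:

// Hypothetical configuration sketch; member names assumed.
linalg::BufferizeToAllocationOptions options;
options.allocOp = linalg::BufferizeToAllocationOptions::AllocOp::MemrefAlloca;
// With memref.alloca no dealloc is emitted; with memref.alloc, the assumed
// emitDealloc flag would control whether a matching memref.dealloc is placed
// at the end of the block.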
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const BufferizeToAllocationOptions &options,
    tensor::PadOp padOp, Attribute memorySpace, Operation *insertionPoint) {
  // tensor.pad does not have a destination operand.
  assert(!options.bufferizeDestinationOnly && "invalid options");
  // ...
  Location loc = padOp.getLoc();

  // Create the buffer allocation.
  Value alloc = createAllocationForTensor(rewriter, loc, padOp.getResult(),
                                          options, memorySpace);

  if (!padOp.hasZeroLowPad() || !padOp.hasZeroHighPad()) {
    // Lower the padding region to a linalg.fill or linalg.generic.
    Operation *fillOp =
        movePaddingToFillOrGenericOp(rewriter, loc, padOp, alloc);
    rewriter.setInsertionPointAfter(fillOp);
  }

  // Copy the source into the subview of the allocation that excludes the
  // padding.
  SmallVector<OpFoldResult> sizes =
      getMixedSizes(rewriter, loc, padOp.getSource());
  SmallVector<OpFoldResult> strides(padOp.getResultType().getRank(),
                                    rewriter.getIndexAttr(1));
  Value subview = rewriter.create<memref::SubViewOp>(
      loc, alloc, /*offsets=*/padOp.getMixedLowPad(), sizes, strides);
  createMemcpy(rewriter, loc, padOp.getSource(), subview, options);

  // Create bufferization.to_tensor with "restrict" and "writable".
  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
      loc, alloc, /*restrict=*/true, /*writable=*/true);
  rewriter.replaceOp(padOp, toTensorOp);
  return alloc;
}
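Taken together, this overload produces an allocation, a fill/generic writing the padding, a copy of the source into a subview, and a to_tensor op that re-enters tensor land. A minimal call-site sketch; the rewriter and the matched padOp are assumed to exist:

// Minimal call-site sketch; `rewriter` and `padOp` assumed.
linalg::BufferizeToAllocationOptions options;
Value alloc = linalg::bufferizeToAllocation(rewriter, options, padOp,
                                            /*memorySpace=*/Attribute(),
                                            /*insertionPoint=*/nullptr);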
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const BufferizeToAllocationOptions &options,
    vector::MaskOp maskOp, Attribute memorySpace, Operation *insertionPoint) {
  assert(llvm::range_size(maskOp.getMaskBlock()->without_terminator()) == 1 &&
         "expected single masked op");
  // ...
  bufferization::BufferizationOptions bufferizationOptions;
  Operation *yieldOp = maskOp.getMaskRegion().front().getTerminator();
  assert(isa<vector::YieldOp>(yieldOp) && "expected yield op terminator");

  // Bufferize the maskable op, placing the allocation right before the mask
  // op unless an explicit insertion point was provided.
  Value alloc = bufferizeToAllocation(
      rewriter, options, maskOp.getMaskableOp(), memorySpace,
      /*insertionPoint=*/insertionPoint ? insertionPoint : maskOp);

  if (options.bufferizeDestinationOnly)
    return alloc;

  // Bufferize the terminator.
  if (failed(cast<bufferization::BufferizableOpInterface>(yieldOp).bufferize(
          rewriter, bufferizationOptions)))
    return nullptr;

  // Erase dead to_tensor ops inside of the mask op.
  SmallVector<Operation *> toTensorOps;
  maskOp.walk([&](bufferization::ToTensorOp toTensorOp) {
    if (toTensorOp->getUses().empty())
      toTensorOps.push_back(toTensorOp.getOperation());
  });
  for (Operation *deadOp : toTensorOps)
    rewriter.eraseOp(deadOp);

  // Bufferize the mask op itself, remembering all uses of its tensor results.
  SmallVector<OpOperand *> resultUses;
  for (Value result : maskOp.getResults())
    if (isa<TensorType>(result.getType()))
      for (OpOperand &use : result.getUses())
        resultUses.push_back(&use);
  if (failed(cast<bufferization::BufferizableOpInterface>(maskOp.getOperation())
                 .bufferize(rewriter, bufferizationOptions)))
    return nullptr;

  // Set "restrict" and "writable" on the newly created to_tensor ops.
  for (OpOperand *resultUse : resultUses) {
    auto toTensorOp =
        resultUse->get().getDefiningOp<bufferization::ToTensorOp>();
    assert(toTensorOp && "expected to_tensor op");
    toTensorOp.setRestrict(true);
    toTensorOp.setWritable(true);
  }
  return alloc;
}
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const BufferizeToAllocationOptions &options,
    bufferization::AllocTensorOp allocTensorOp, Attribute memorySpace,
    Operation *insertionPoint) {
  Location loc = allocTensorOp.getLoc();
  // ...
  Value alloc = createAllocationForTensor(
      rewriter, loc, allocTensorOp.getResult(), options, memorySpace);

  // Create bufferization.to_tensor with "restrict" and "writable".
  Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
      loc, alloc, /*restrict=*/true, /*writable=*/true);
  rewriter.replaceOp(allocTensorOp, toTensorOp);
  return alloc;
}
FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle(
    RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp) {
  Location loc = fromElementsOp.getLoc();
  RankedTensorType tensorType =
      cast<RankedTensorType>(fromElementsOp.getType());
  auto shape = tensorType.getShape();

  // Create a tensor.empty destination.
  auto emptyOp = rewriter.create<EmptyOp>(loc, tensorType, ValueRange());

  // Degenerate case tensor<elem_type>: a single tensor.insert suffices.
  if (shape.empty()) {
    Operation *res = rewriter.replaceOpWithNewOp<tensor::InsertOp>(
        fromElementsOp, fromElementsOp.getElements().front(),
        emptyOp.getResult(), ValueRange());
    return res;
  }

  // Create index constants for the range of possible indices [0, maxDim).
  auto maxDim = *llvm::max_element(shape);
  SmallVector<Value, 2> constants;
  constants.reserve(maxDim);
  for (int i = 0; i < maxDim; ++i)
    constants.push_back(rewriter.create<arith::ConstantIndexOp>(loc, i));

  // Traverse all elements and create a chain of tensor.insert ops.
  auto elementIt = fromElementsOp.getElements().begin();
  SmallVector<Value, 2> indices(tensorType.getRank(), constants[0]);
  Value result = createInserts(rewriter, loc, /*dim=*/0, emptyOp.getResult(),
                               shape, constants, elementIt, indices);

  // Replace tensor.from_elements with the result of the last insertion.
  rewriter.replaceOp(fromElementsOp, result);
  return result.getDefiningOp();
}
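The same entry point is overloaded for tensor.generate and tensor.pad (below). A minimal call-site sketch inside a rewrite; the rewriter and the matched op are assumed to exist:

// Minimal call-site sketch; `rewriter` and `fromElementsOp` assumed.
FailureOr<Operation *> replacement =
    linalg::rewriteInDestinationPassingStyle(rewriter, fromElementsOp);
if (failed(replacement))
  return failure();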
FailureOr<Operation *>
linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                         tensor::GenerateOp generateOp) {
  // Only ops with exactly one block are supported.
  if (!generateOp.getBody().hasOneBlock())
    return failure();

  Location loc = generateOp.getLoc();
  RankedTensorType tensorType = cast<RankedTensorType>(generateOp.getType());

  // Create a tensor.empty destination.
  auto emptyOp =
      rewriter.create<EmptyOp>(loc, tensorType, generateOp.getDynamicExtents());

  // Create a linalg.generic with parallel iterators and an identity map.
  SmallVector<utils::IteratorType> iteratorTypes(tensorType.getRank(),
                                                 utils::IteratorType::parallel);
  SmallVector<AffineMap> indexingMaps(
      1, rewriter.getMultiDimIdentityMap(tensorType.getRank()));
  auto genericOp = rewriter.create<linalg::GenericOp>(
      loc, tensorType, /*inputs=*/ValueRange(),
      /*outputs=*/ValueRange{emptyOp.getResult()},
      indexingMaps, iteratorTypes);
  Block *body = rewriter.createBlock(&genericOp->getRegion(0), {},
                                     tensorType.getElementType(), loc);
  rewriter.setInsertionPointToStart(body);

  // Replace the generate body's block arguments with linalg.index ops and
  // merge the body into the generic op.
  SmallVector<Value> bbArgReplacements;
  for (int64_t i = 0; i < tensorType.getRank(); ++i)
    bbArgReplacements.push_back(rewriter.create<linalg::IndexOp>(loc, i));
  rewriter.mergeBlocks(&generateOp.getBody().front(), body, bbArgReplacements);

  // Update the terminator: tensor.yield becomes linalg.yield.
  auto yieldOp = cast<tensor::YieldOp>(body->getTerminator());
  rewriter.replaceOpWithNewOp<linalg::YieldOp>(yieldOp, yieldOp.getValue());

  // Replace tensor.generate with the result of the linalg.generic.
  rewriter.replaceOp(generateOp, genericOp->getResult(0));
  return genericOp.getOperation();
}
FailureOr<Operation *>
linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                         tensor::PadOp padOp) {
  // Only ops with exactly one block are supported.
  if (!padOp.getBodyRegion().hasOneBlock())
    return failure();

  Location loc = padOp.getLoc();
  RankedTensorType resultType = padOp.getResultType();
  ReifiedRankedShapedTypeDims reifiedShape;
  if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
    return rewriter.notifyMatchFailure(
        padOp, "failed to reify tensor.pad op result shape");
  SmallVector<Value> dynamicSizes;
  for (int64_t i = 0; i < resultType.getRank(); ++i)
    if (resultType.isDynamicDim(i))
      dynamicSizes.push_back(cast<Value>(reifiedShape[0][i]));

  // If the op has a "nofold" attribute and all paddings are known to be
  // zero, insert an explicit linalg.copy.
  if (padOp.getNofoldAttr() &&
      llvm::all_of(padOp.getMixedLowPad(), isZeroIndex) &&
      llvm::all_of(padOp.getMixedHighPad(), isZeroIndex)) {
    using bufferization::AllocTensorOp;
    Value allocated =
        rewriter.create<AllocTensorOp>(loc, resultType, dynamicSizes);
    auto copyOp = rewriter.replaceOpWithNewOp<linalg::CopyOp>(
        padOp, padOp.getSource(), allocated);
    return copyOp.getOperation();
  }

  // Otherwise, lower the padding region onto a tensor.empty destination and
  // insert the source into it.
  Value empty = rewriter.create<EmptyOp>(loc, resultType, dynamicSizes);
  Operation *fillOp =
      movePaddingToFillOrGenericOp(rewriter, loc, padOp, empty);
  rewriter.setInsertionPointAfter(fillOp);

  // Create the tensor.insert_slice of the source into the filled tensor.
  SmallVector<OpFoldResult> sliceSizes =
      getMixedSizes(rewriter, loc, padOp.getSource());
  SmallVector<OpFoldResult> sliceStrides(resultType.getRank(),
                                         rewriter.getIndexAttr(1));
  auto insertSliceOp = rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
      padOp, padOp.getSource(), fillOp->getResult(0),
      padOp.getMixedLowPad(), sliceSizes, sliceStrides);
  return insertSliceOp.getOperation();
}
Value linalg::bufferizeToAllocation(
    RewriterBase &rewriter, const BufferizeToAllocationOptions &options,
    Operation *op, Attribute memorySpace, Operation *insertionPoint) {
  using namespace bufferization;

  // Dispatch to the specialized overloads for ops with custom handling.
  if (auto padOp = dyn_cast<tensor::PadOp>(op))
    return bufferizeToAllocation(rewriter, options, padOp, memorySpace,
                                 insertionPoint);
  if (auto maskOp = dyn_cast<vector::MaskOp>(op))
    return bufferizeToAllocation(rewriter, options, maskOp, memorySpace,
                                 insertionPoint);
  if (auto allocTensorOp = dyn_cast<bufferization::AllocTensorOp>(op))
    return bufferizeToAllocation(rewriter, options, allocTensorOp, memorySpace,
                                 insertionPoint);

  // Only bufferizable ops are supported.
  auto bufferizableOp = dyn_cast<BufferizableOpInterface>(op);
  // ...
  BufferizationOptions bufferizationOptions;
  AnalysisState state(bufferizationOptions);

  if (!options.bufferizeDestinationOnly) {
    // Ops with nested tensor ops are not supported yet: this function
    // bufferizes the given op itself, but not its body.
    op->walk([&](Operation *nestedOp) {
      if (op == nestedOp)
        return;
      if (llvm::any_of(nestedOp->getOperands(),
                       [](Value v) { return isa<TensorType>(v.getType()); }))
        llvm_unreachable("ops with nested tensor ops are not supported yet");
      if (llvm::any_of(nestedOp->getResults(),
                       [](Value v) { return isa<TensorType>(v.getType()); }))
        llvm_unreachable("ops with nested tensor ops are not supported yet");
    });
  }

  // Gather the ranked tensor results that do not already bufferize to an
  // allocation.
  SmallVector<OpResult> tensorResults;
  for (OpResult result : op->getResults()) {
    if (!isa<TensorType>(result.getType()))
      continue;
    // Unranked tensors are not supported.
    if (!isa<RankedTensorType>(result.getType()))
      return nullptr;
    // Ops that bufferize to an allocation are supported.
    if (bufferizableOp.bufferizesToAllocation(result))
      continue;
    tensorResults.push_back(result);
  }

  // Gather all out-of-place tensor operands.
  SmallVector<OpOperand *> outOfPlaceOperands;
  SmallVector<OpOperand *> resultUses;
  auto addOutOfPlaceOperand = [&](OpOperand *operand) {
    if (!llvm::is_contained(outOfPlaceOperands, operand))
      outOfPlaceOperands.push_back(operand);
  };
  for (OpResult result : tensorResults) {
    AliasingOpOperandList aliasingOperands =
        state.getAliasingOpOperands(result);
    for (const AliasingOpOperand &operand : aliasingOperands) {
      addOutOfPlaceOperand(operand.opOperand);
      for (OpOperand &resultUse : result.getUses())
        resultUses.push_back(&resultUse);
    }
  }
  for (OpOperand &operand : op->getOpOperands()) {
    if (!state.bufferizesToMemoryWrite(operand))
      continue;
    if (!isa<RankedTensorType>(operand.get().getType()))
      continue;
    addOutOfPlaceOperand(&operand);
  }
  // Currently only a single out-of-place operand is supported.
  if (outOfPlaceOperands.size() != 1)
    return nullptr;

  // Allocate buffers and initialize them with a copy of the operand data
  // where needed.
  SmallVector<Value> allocs;
  for (OpOperand *operand : outOfPlaceOperands) {
    Value alloc = createAllocationForTensor(
        rewriter, op->getLoc(), operand->get(), options, memorySpace);
    allocs.push_back(alloc);
    if (!state.findDefinitions(operand->get()).empty()) {
      // Initialize the buffer with a copy of the operand data; not needed
      // if the tensor is uninitialized.
      createMemcpy(rewriter, op->getLoc(), operand->get(), alloc, options);
    }
    auto toTensorOp = rewriter.create<ToTensorOp>(op->getLoc(), alloc);
    operand->set(toTensorOp);
    if (options.bufferizeDestinationOnly) {
      // ...
      toTensorOp.setRestrict(true);
      toTensorOp.setWritable(true);
    }
  }

  if (options.bufferizeDestinationOnly)
    return allocs.front();

  // Bufferize the op.
  if (failed(bufferizableOp.bufferize(rewriter, bufferizationOptions)))
    return nullptr;

  // Set "restrict" and "writable" on the resulting to_tensor ops.
  for (OpOperand *resultUse : resultUses) {
    auto toTensorOp = resultUse->get().getDefiningOp<ToTensorOp>();
    assert(toTensorOp && "expected to_tensor op");
    toTensorOp.setRestrict(true);
    toTensorOp.setWritable(true);
  }
  return allocs.front();
}
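The generic overload is also the hook for bufferizing only the destination (pinning a buffer while keeping the op itself on tensors). A minimal sketch; the rewriter and op are assumed to exist:

// Minimal sketch; `rewriter` and `op` assumed. `alloc` is null if the op
// could not be handled (e.g., unranked result, multiple buffers).
linalg::BufferizeToAllocationOptions options;
options.bufferizeDestinationOnly = true;
Value alloc = linalg::bufferizeToAllocation(rewriter, options, op,
                                            /*memorySpace=*/Attribute(),
                                            /*insertionPoint=*/nullptr);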
template <typename OpTy>
LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
                                                 PatternRewriter &rewriter) {
  return linalg::rewriteInDestinationPassingStyle(rewriter, op);
}

void linalg::populateConvertToDestinationStylePatterns(
    RewritePatternSet &patterns) {
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
  patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
}
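Applying these patterns typically goes through the greedy pattern driver. A minimal usage sketch, assuming a standard pass context where ctx and funcOp are available:

// Minimal usage sketch inside a pass; `ctx` and `funcOp` assumed.
RewritePatternSet patterns(ctx);
linalg::populateConvertToDestinationStylePatterns(patterns);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  signalPassFailure();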
Functions defined in this file:

static Value createInserts(RewriterBase &rewriter, Location loc, int dim, Value destination, ArrayRef<int64_t> shape, ArrayRef<Value> constants, OperandRange::iterator &elementIt, SmallVectorImpl<Value> &indices)
static void createMemcpy(OpBuilder &b, Location loc, Value tensorSource, Value memrefDest, const linalg::BufferizeToAllocationOptions &options)
    Create a memcpy from the given source tensor to the given destination memref.
static SmallVector<Value> reifyOrComputeDynamicSizes(OpBuilder &b, Value value)
static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter, Location loc, PadOp padOp, Value dest)
static Value createAllocationForTensor(RewriterBase &rewriter, Location loc, Value value, const linalg::BufferizeToAllocationOptions &options, Attribute memorySpace = {})

Public entry points:

Value linalg::bufferizeToAllocation(RewriterBase &rewriter, const BufferizeToAllocationOptions &options, tensor::PadOp padOp, Attribute memorySpace = {}, Operation *insertionPoint = nullptr)
    Materialize a buffer allocation for the given tensor.pad op and lower the op to linalg. (Overloads for vector.mask, bufferization.alloc_tensor, and generic bufferizable ops are defined above.)
FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle(RewriterBase &rewriter, tensor::FromElementsOp fromElementsOp)
    Rewrite tensor.from_elements to linalg.generic. (Overloads for tensor.generate and tensor.pad are defined above.)
void linalg::populateConvertToDestinationStylePatterns(RewritePatternSet &patterns)
    Populate patterns that convert non-destination-style ops to destination-style ops.