// Excerpt; "// ..." marks source lines omitted from this listing.
LogicalResult matchAndRewrite(ConcatOp concatOp,
                              // ...
  FailureOr<Value> dest =
      // ...
  auto empty = dest->getDefiningOp<tensor::EmptyOp>();
  // ...
  int64_t dim = concatOp.getDim();
  // ...
  int64_t rank = concatOp.getResultType().getRank();
  // ...
  for (auto [idx, input] :
       // ...
    partialSums.push_back(sum);
    offsetStrides.push_back(
        rewriter.createOrFold<tensor::DimOp>(loc, input, dimValue));
  // ...
  auto partialSumMap = AffineMap::get(concatOp.getInputs().size(), 0,
                                      // ...
          rewriter, loc, partialSumMap, offsetStrides);
  // ...
  for (auto [input, offset] :
       llvm::zip_equal(concatOp.getInputs(), dimOffsets)) {
    // ...
    offsets[dim] = offset;
    // ...
        loc, input, result, offsets, sizes, strides);
  // ...
      concatOp, concatOp.getResultType(), result);
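In the fragment above, each input is written into the destination at an offset along `dim` equal to the total size of the inputs that precede it. The fragment computes these offsets with a partial-sum affine map over `tensor.dim` values and then threads `result` through a chain of slice insertions. The following is a minimal standalone C++ sketch of that idea on plain row-major 2-D buffers. It is not MLIR code: `Tensor2D`, `insertSlice`, and `decomposeConcat` are hypothetical names, and only 2-D tensors with unit strides are modeled.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical 2-D row-major tensor used only for this sketch (not an MLIR type).
struct Tensor2D {
  int64_t rows = 0, cols = 0;
  std::vector<float> data;

  Tensor2D(int64_t r, int64_t c, float fill = 0.0f)
      : rows(r), cols(c), data(static_cast<std::size_t>(r * c), fill) {}

  // Size along dimension 0 (rows) or 1 (cols).
  int64_t size(int64_t dim) const { return dim == 0 ? rows : cols; }

  float &at(int64_t r, int64_t c) {
    return data[static_cast<std::size_t>(r * cols + c)];
  }
  float at(int64_t r, int64_t c) const {
    return data[static_cast<std::size_t>(r * cols + c)];
  }
};

// Hypothetical model of tensor.insert_slice with unit strides: copies `src`
// into `dest` starting at (rowOff, colOff) and returns the updated tensor.
// `dest` is taken by value so each insertion yields a new tensor value.
Tensor2D insertSlice(const Tensor2D &src, Tensor2D dest, int64_t rowOff,
                     int64_t colOff) {
  for (int64_t r = 0; r < src.rows; ++r)
    for (int64_t c = 0; c < src.cols; ++c)
      dest.at(rowOff + r, colOff + c) = src.at(r, c);
  return dest;
}

// Hypothetical model of the decomposition: create an "empty" destination of
// the concatenated shape, then insert each input at a running offset along dim.
Tensor2D decomposeConcat(const std::vector<Tensor2D> &inputs, int64_t dim) {
  assert(!inputs.empty() && (dim == 0 || dim == 1));
  const int64_t other = 1 - dim;

  // Result size along `dim` is the sum of the input sizes; the other
  // dimension must agree across all inputs.
  int64_t total = 0;
  for (const Tensor2D &t : inputs) {
    assert(t.size(other) == inputs.front().size(other));
    total += t.size(dim);
  }
  Tensor2D result = dim == 0 ? Tensor2D(total, inputs.front().cols)
                             : Tensor2D(inputs.front().rows, total);

  // The offset of input k along `dim` is the exclusive prefix sum of the
  // sizes of inputs 0..k-1; the offset along the other dimension stays 0.
  int64_t offset = 0;
  for (const Tensor2D &t : inputs) {
    result = dim == 0 ? insertSlice(t, result, offset, 0)
                      : insertSlice(t, result, 0, offset);
    offset += t.size(dim);
  }
  return result;
}
```

For example, concatenating a 2x3 tensor and a 2x4 tensor along dim 1 with `decomposeConcat({a, b}, 1)` yields a 2x7 tensor whose columns 3 through 6 come from the second input. Returning a fresh `Tensor2D` from `insertSlice` mirrors the value semantics of `tensor.insert_slice`, where each insertion produces a new tensor value that feeds the next one.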
void populateDecomposeTensorConcatPatterns(RewritePatternSet &patterns)
Populates patterns with patterns that decompose tensor.concat into tensor.empty of a tensor of the co...