37 int64_t ratio = control.
ratio;
38 unsigned insertSplitIndex = control.
index;
39 unsigned insertSplitDimension = control.
index;
44 op.getReductionDims(dims);
48 unsigned reductionDim = dims[0];
50 insertSplitDimension = reductionDim + 1;
53 int64_t reductionDimSize = loopRanges[reductionDim];
54 if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
56 op,
"Reduction dimension not divisible by split ratio");
57 if (op.getNumDpsInits() != 1)
59 if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
61 "compared to intermediate tensor size");
65 combinerOps.size() != 1)
70 if (!identity.has_value())
77 for (
OpOperand *operand : op.getDpsInputOperands()) {
78 AffineMap map = op.getMatchingIndexingMap(operand);
83 for (
unsigned idx : llvm::seq<unsigned>(0, map.
getNumResults())) {
85 if (reductionDim == dim) {
87 newShape.push_back(op.getShape(operand)[idx] / ratio);
88 newShape.push_back(ratio);
93 newShape.push_back(ratio);
94 newShape.push_back(op.getShape(operand)[idx] / ratio);
99 reassociation.push_back({index++, index++});
102 newShape.push_back(op.getShape(operand)[idx]);
105 reassociation.push_back({index++});
110 if (newShape == op.getShape(operand)) {
111 newInputs.push_back(operand->get());
116 cast<RankedTensorType>(operand->get().getType()).getElementType());
119 loc, newType, operand->get(), reassociation);
120 newInputs.push_back(newInput);
126 AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
129 for (
unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
130 if (insertSplitIndex == idx) {
131 newOutputShape.push_back(ratio);
134 if (idx < oldShape.size()) {
135 newOutputShape.push_back(oldShape[idx]);
137 outputExpr.push_back(
141 Value emptyOrAllocTensor;
143 emptyOrAllocTensor = b.
create<bufferization::AllocTensorOp>(
146 op.getRegionOutputArgs()[0].getType()),
149 emptyOrAllocTensor = b.
create<tensor::EmptyOp>(
150 loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
152 Value constantOp = b.
create<arith::ConstantOp>(loc, *identity);
153 Value identityTensor =
154 b.
create<linalg::FillOp>(op->
getLoc(), constantOp, emptyOrAllocTensor)
160 for (
auto [index, iteratorType] :
162 if (insertSplitDimension == index)
163 newIteratorTypes.push_back(utils::IteratorType::parallel);
164 newIteratorTypes.push_back(iteratorType);
166 if (insertSplitDimension == op.getIteratorTypesArray().size()) {
167 newIteratorTypes.push_back(utils::IteratorType::parallel);
171 GenericOp genericOp = b.
create<GenericOp>(
173 ValueRange({identityTensor}), newMaps, newIteratorTypes);
175 genericOp.getRegion().
begin());
179 unsigned intermRank = newOutputShape.size();
183 for (
unsigned i : llvm::seq<unsigned>(0, intermRank)) {
184 if (insertSplitIndex == i) {
185 reductionIteratorTypes.push_back(utils::IteratorType::reduction);
188 reductionIteratorTypes.push_back(utils::IteratorType::parallel);
194 auto reduction = b.
create<GenericOp>(
196 op.getDpsInits(), reductionMaps, reductionIteratorTypes,
206 identityTensor.getDefiningOp<FillOp>(),
207 cast<LinalgOp>(genericOp.getOperation()),
216 unsigned reductionDimPos,
217 int64_t reductionRatio) {
220 AffineMap map = op.getMatchingIndexingMap(&opOperand);
225 reductionDim, reductionDim * reductionRatio + reductionDimP1,
227 return map.
compose(composeMap);
231 unsigned reductionDimPos, int64_t size) {
233 AffineMap map = op.getMatchingIndexingMap(&opOperand);
252 int64_t splitFactor = control.
ratio;
253 unsigned insertSplitDimension = control.
index;
254 if (splitFactor <= 1)
258 op.getReductionDims(dims);
262 unsigned reductionDimPos = dims[0];
264 int64_t reductionDimSize = loopRanges[reductionDimPos];
265 if (reductionDimSize == ShapedType::kDynamic ||
266 reductionDimSize % splitFactor != 0 ||
267 insertSplitDimension >= loopRanges.size())
269 op,
"first reduction dimension not divisible by split factor");
276 for (
Operation *reductionOp : combinerOps) {
277 std::optional<TypedAttr> neutralElement =
279 if (!neutralElement.has_value())
281 neutralElements.push_back(*neutralElement);
283 if (!llvm::all_of(neutralElements, [](
Attribute attr) {
return attr; }))
287 if (op.getNumDpsInits() !=
static_cast<int64_t
>(neutralElements.size()))
307 newOutputs.reserve(op.getNumDpsInits());
310 fillOps.reserve(op.getNumDpsInits());
311 for (
auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
312 Value rankedTensor = std::get<0>(it).get();
313 auto t = cast<RankedTensorType>(rankedTensor.
getType());
315 reductionDimSize / splitFactor, insertSplitDimension);
318 Value emptyOrAllocTensor;
321 b.
create<bufferization::AllocTensorOp>(loc, newT, dims);
323 emptyOrAllocTensor = b.
create<tensor::EmptyOp>(loc, newT.getShape(),
324 t.getElementType(), dims);
326 Value constantOp = b.
create<arith::ConstantOp>(loc, std::get<1>(it));
328 b.
create<linalg::FillOp>(op->
getLoc(), constantOp, emptyOrAllocTensor));
329 newOutputs.push_back(fillOps.back().getResult(0));
330 emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.
getDefiningOp());
337 for (
OpOperand *o : op.getDpsInputOperands())
340 auto nDims = op.getNumLoops() + 1;
343 newMaps.push_back(
AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
348 for (
OpOperand &o : op.getDpsInitsMutable())
350 reductionDimSize / splitFactor));
357 newInputs.push_back(b.
create<tensor::EmptyOp>(
364 auto iteratorTypes = op.getIteratorTypesArray();
365 iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
366 utils::IteratorType::parallel);
367 GenericOp genericOp =
369 newOutputs, newMaps, iteratorTypes);
371 genericOp.getRegion().
begin());
372 genericOp.getRegion().front().insertArgument(reductionDimPos,
384 llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
385 Value reindexedOutput = std::get<0>(it);
386 Value originalOutput = std::get<1>(it);
387 auto originalOutputType = cast<RankedTensorType>(originalOutput.
getType());
394 originalOutputType.getRank() + 1, utils::IteratorType::parallel);
395 reductionIteratorTypes[insertSplitDimension] =
396 utils::IteratorType::reduction;
399 auto reductionOp = b.
create<GenericOp>(
405 reductionIteratorTypes,
410 b.create<linalg::YieldOp>(loc, clonedReductionOp->
getResult(0));
414 results.push_back(reductionOp);
418 assert(fillOps.size() == results.size() && results.size() == 1);
419 b.
replaceOp(op, results.front()->getResults());
421 cast<LinalgOp>(genericOp.getOperation()),
433 controlSplitReductionFn(std::move(controlSplitReductionFn)),
434 useAlloc(useAlloc) {}
436 LogicalResult matchAndRewrite(LinalgOp op,
438 return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
452 controlSplitReductionFn, useAlloc);
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand, unsigned reductionDimPos, int64_t reductionRatio)
Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...) TODO: Additional pattern to rewrite f(i,...
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand, unsigned reductionDimPos, int64_t size)
A multi-dimensional affine map Affine map's are immutable like Type's, and they are uniqued.
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
AffineMap dropResult(int64_t pos) const
Returns a new AffineMap with the same number of dims and symbols and one less result at pos,...
unsigned getNumDims() const
unsigned getNumResults() const
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Attributes are known-constant values of operations.
AffineMap getMultiDimIdentityMap(unsigned rank)
IntegerType getIntegerType(unsigned width)
AffineExpr getAffineDimExpr(unsigned position)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
This class represents an operand of an operation.
Operation is the basic unit of execution within MLIR.
void setOperand(unsigned idx, Value value)
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
MLIRContext * getContext()
Return the context this operation is associated with.
Location getLoc()
The source location the operation was defined or derived from.
unsigned getNumOperands()
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
result_type_range getResultTypes()
result_range getResults()
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This is a builder type that keeps local references to arguments.
Builder & insertDim(int64_t val, unsigned pos)
Insert a val into shape @pos.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void inlineRegionBefore(Region ®ion, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated to the give op.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
void populateSplitReductionPattern(RewritePatternSet &patterns, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Patterns to apply splitReduction below.
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
SmallVector< Value > createDynamicDimValues(OpBuilder &b, Location loc, Value rankedTensor)
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Apply transformation to split the single linalg op reduction into a parallel and reduction dimension.