MLIR  20.0.0git
SplitReduction.cpp
Go to the documentation of this file.
1 //===-------- SplitReduction.cpp - Split reduction dimension --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a linalg transformation that splits a reduction
10 // dimension into a parallel and a reduction dimension.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include <optional>
15 #include <utility>
16 
25 #include "mlir/IR/PatternMatch.h"
26 
27 using namespace mlir;
28 using namespace mlir::linalg;
29 
30 FailureOr<SplitReductionResult> mlir::linalg::splitReduction(
31  RewriterBase &b, LinalgOp op,
32  const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
34  b.setInsertionPoint(op);
35 
36  SplitReductionOptions control = controlSplitReductionFn(op);
37  int64_t ratio = control.ratio;
38  unsigned insertSplitIndex = control.index;
39  unsigned insertSplitDimension = control.index;
40  if (ratio <= 1)
41  return b.notifyMatchFailure(op, "split ratio needs to be greater than 1");
42 
44  op.getReductionDims(dims);
45 
46  if (dims.size() != 1)
47  return b.notifyMatchFailure(op, "needs a single reduction dimension");
48  unsigned reductionDim = dims[0];
49  if (control.innerParallel) {
50  insertSplitDimension = reductionDim + 1;
51  }
52  SmallVector<int64_t, 4> loopRanges = op.getStaticLoopRanges();
53  int64_t reductionDimSize = loopRanges[reductionDim];
54  if (reductionDimSize == ShapedType::kDynamic || reductionDimSize % ratio != 0)
55  return b.notifyMatchFailure(
56  op, "Reduction dimension not divisible by split ratio");
57  if (op.getNumDpsInits() != 1)
58  return b.notifyMatchFailure(op, "More than one output in split reduction");
59  if (insertSplitIndex > op.getShape(op.getDpsInitOperand(0)).size())
60  return b.notifyMatchFailure(op, "Insert dimension position too large "
61  "compared to intermediate tensor size");
62 
63  SmallVector<Operation *, 4> combinerOps;
64  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps) ||
65  combinerOps.size() != 1)
66  return b.notifyMatchFailure(op, "Cannot match the reduction pattern");
67 
68  Operation *reductionOp = combinerOps[0];
69  std::optional<TypedAttr> identity = arith::getNeutralElement(reductionOp);
70  if (!identity.has_value())
71  return b.notifyMatchFailure(op, "Unknown identity value for the reduction");
72 
73  Location loc = op->getLoc();
74  SmallVector<Value> newInputs;
75  SmallVector<AffineMap> newMaps;
76  // Calculate the new shapes and indexing maps of the input operands.
77  for (OpOperand *operand : op.getDpsInputOperands()) {
78  AffineMap map = op.getMatchingIndexingMap(operand);
79  SmallVector<int64_t> newShape;
82  unsigned index = 0;
83  for (unsigned idx : llvm::seq<unsigned>(0, map.getNumResults())) {
84  unsigned dim = map.getDimPosition(idx);
85  if (reductionDim == dim) {
86  if (control.innerParallel) {
87  newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
88  newShape.push_back(ratio); // parallel (insert)
89  exprs.push_back(
90  b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
91  exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
92  } else {
93  newShape.push_back(ratio); // parallel (insert)
94  newShape.push_back(op.getShape(operand)[idx] / ratio); // reduce
95  exprs.push_back(b.getAffineDimExpr(insertSplitDimension));
96  exprs.push_back(
97  b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
98  }
99  reassociation.push_back({index++, index++});
100  continue;
101  }
102  newShape.push_back(op.getShape(operand)[idx]);
103  exprs.push_back(
104  b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
105  reassociation.push_back({index++});
106  }
107  newMaps.push_back(
108  AffineMap::get(map.getNumDims() + 1, 0, exprs, op.getContext()));
109  // If the shape is unchanged the input doesn't change.
110  if (newShape == op.getShape(operand)) {
111  newInputs.push_back(operand->get());
112  continue;
113  }
114  Type newType = RankedTensorType::get(
115  newShape,
116  cast<RankedTensorType>(operand->get().getType()).getElementType());
117 
118  Value newInput = b.create<tensor::ExpandShapeOp>(
119  loc, newType, operand->get(), reassociation);
120  newInputs.push_back(newInput);
121  }
122 
123  // Calculate the new output map and shape, we insert the new dimension based
124  // on the index returned by `controlSplitReductionFn`.
125  SmallVector<int64_t> newOutputShape;
126  AffineMap oldOutputMap = op.getMatchingIndexingMap(op.getDpsInitOperand(0));
127  ArrayRef<int64_t> oldShape = op.getShape(op.getDpsInitOperand(0));
128  SmallVector<AffineExpr> outputExpr;
129  for (unsigned idx : llvm::seq<unsigned>(0, oldShape.size() + 1)) {
130  if (insertSplitIndex == idx) {
131  newOutputShape.push_back(ratio);
132  outputExpr.push_back(b.getAffineDimExpr(insertSplitDimension));
133  }
134  if (idx < oldShape.size()) {
135  newOutputShape.push_back(oldShape[idx]);
136  unsigned dim = oldOutputMap.getDimPosition(idx);
137  outputExpr.push_back(
138  b.getAffineDimExpr(dim < insertSplitDimension ? dim : dim + 1));
139  }
140  }
141  Value emptyOrAllocTensor;
142  if (useAlloc) {
143  emptyOrAllocTensor = b.create<bufferization::AllocTensorOp>(
144  loc,
145  RankedTensorType::get(newOutputShape,
146  op.getRegionOutputArgs()[0].getType()),
147  ValueRange{});
148  } else {
149  emptyOrAllocTensor = b.create<tensor::EmptyOp>(
150  loc, newOutputShape, op.getRegionOutputArgs()[0].getType());
151  }
152  Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
153  Value identityTensor =
154  b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor)
155  .getResult(0);
156 
157  newMaps.push_back(AffineMap::get(oldOutputMap.getNumDims() + 1, 0, outputExpr,
158  op.getContext()));
159  SmallVector<utils::IteratorType> newIteratorTypes;
160  for (auto [index, iteratorType] :
161  llvm::enumerate(op.getIteratorTypesArray())) {
162  if (insertSplitDimension == index)
163  newIteratorTypes.push_back(utils::IteratorType::parallel);
164  newIteratorTypes.push_back(iteratorType);
165  }
166  if (insertSplitDimension == op.getIteratorTypesArray().size()) {
167  newIteratorTypes.push_back(utils::IteratorType::parallel);
168  }
169  // Create the new op matching the original op with an extra parallel
170  // dimension.
171  GenericOp genericOp = b.create<GenericOp>(
172  loc, TypeRange({emptyOrAllocTensor.getType()}), newInputs,
173  ValueRange({identityTensor}), newMaps, newIteratorTypes);
174  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
175  genericOp.getRegion().begin());
176 
177  // Then create a new reduction that only reduce the newly added dimension
178  // from the previous op.
179  unsigned intermRank = newOutputShape.size();
180  AffineMap inputMap = b.getMultiDimIdentityMap(intermRank);
181  SmallVector<utils::IteratorType> reductionIteratorTypes;
183  for (unsigned i : llvm::seq<unsigned>(0, intermRank)) {
184  if (insertSplitIndex == i) {
185  reductionIteratorTypes.push_back(utils::IteratorType::reduction);
186  } else {
187  exprs.push_back(b.getAffineDimExpr(i));
188  reductionIteratorTypes.push_back(utils::IteratorType::parallel);
189  }
190  }
191  AffineMap outputMap = AffineMap::get(intermRank, 0, exprs, op.getContext());
192  SmallVector<AffineMap> reductionMaps = {inputMap, outputMap};
193 
194  auto reduction = b.create<GenericOp>(
195  loc, op->getResultTypes(), ValueRange({genericOp.getResult(0)}),
196  op.getDpsInits(), reductionMaps, reductionIteratorTypes,
197  [reductionOp](OpBuilder &b, Location loc, ValueRange inputs) {
198  Operation *clonedReductionOp = b.clone(*reductionOp);
199  clonedReductionOp->setOperand(0, inputs[0]);
200  clonedReductionOp->setOperand(1, inputs[1]);
201  b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
202  });
203  b.replaceOp(op, reduction.getResults());
204 
205  return SplitReductionResult{emptyOrAllocTensor.getDefiningOp(),
206  identityTensor.getDefiningOp<FillOp>(),
207  cast<LinalgOp>(genericOp.getOperation()),
208  reduction};
209 }
210 
211 /// Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...)
212 /// TODO: Additional pattern to rewrite f(i, j, k * ratio + kk, ...) into
213 /// f(i, j, k, kk, ...) with a proper ExpandShapeOp. This is probably better
214 /// done as a transform to enable better vectorization.
215 static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand,
216  unsigned reductionDimPos,
217  int64_t reductionRatio) {
218  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
219  auto reductionDimP1 = getAffineDimExpr(reductionDimPos + 1, op.getContext());
220  AffineMap map = op.getMatchingIndexingMap(&opOperand);
221  AffineMap idMap =
222  AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
223  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
224  AffineMap composeMap = shiftedIdMap.replace(
225  reductionDim, reductionDim * reductionRatio + reductionDimP1,
226  shiftedIdMap.getNumDims(), /*numSymbols=*/0);
227  return map.compose(composeMap);
228 }
229 
230 static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand,
231  unsigned reductionDimPos, int64_t size) {
232  auto reductionDim = getAffineDimExpr(reductionDimPos, op.getContext());
233  AffineMap map = op.getMatchingIndexingMap(&opOperand);
234  AffineMap idMap =
235  AffineMap::getMultiDimIdentityMap(map.getNumDims(), op.getContext());
236  AffineMap shiftedIdMap = idMap.shiftDims(1, /*offset=*/reductionDimPos + 1);
237  return map.compose(shiftedIdMap).insertResult(reductionDim, reductionDimPos);
238 }
239 
240 /// Core rewrite implementation.
241 FailureOr<SplitReductionResult> mlir::linalg::splitReductionByScaling(
242  RewriterBase &b, LinalgOp op,
243  const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
244  OpBuilder::InsertionGuard guard(b);
245  b.setInsertionPoint(op);
246 
247  // Matcher part, enforce preconditions.
248  SplitReductionOptions control = controlSplitReductionFn(op);
249  if (control.innerParallel)
250  return b.notifyMatchFailure(op, "innerParallel not supported");
251 
252  int64_t splitFactor = control.ratio;
253  unsigned insertSplitDimension = control.index;
254  if (splitFactor <= 1)
255  return b.notifyMatchFailure(op, "split factor needs to be greater than 1");
256 
258  op.getReductionDims(dims);
259  if (dims.empty())
260  return b.notifyMatchFailure(op, "needs at least 1 reduction dimension");
261 
262  unsigned reductionDimPos = dims[0];
263  SmallVector<int64_t> loopRanges = op.getStaticLoopRanges();
264  int64_t reductionDimSize = loopRanges[reductionDimPos];
265  if (reductionDimSize == ShapedType::kDynamic ||
266  reductionDimSize % splitFactor != 0 ||
267  insertSplitDimension >= loopRanges.size())
268  return b.notifyMatchFailure(
269  op, "first reduction dimension not divisible by split factor");
270 
271  SmallVector<Operation *> combinerOps;
272  if (!matchReduction(op.getRegionOutputArgs(), 0, combinerOps))
273  return b.notifyMatchFailure(op, "cannot match a reduction pattern");
274 
275  SmallVector<TypedAttr> neutralElements;
276  for (Operation *reductionOp : combinerOps) {
277  std::optional<TypedAttr> neutralElement =
278  arith::getNeutralElement(reductionOp);
279  if (!neutralElement.has_value())
280  return b.notifyMatchFailure(op, "cannot find neutral element.");
281  neutralElements.push_back(*neutralElement);
282  }
283  if (!llvm::all_of(neutralElements, [](Attribute attr) { return attr; }))
284  return b.notifyMatchFailure(op, "unknown reduction neutral");
285 
286  // TODO: relax this when multi-reduction support is available.
287  if (op.getNumDpsInits() != static_cast<int64_t>(neutralElements.size()))
288  return b.notifyMatchFailure(op, "expect one reduction per output");
289 
290  // Rewrite part.
291  // Step 1. Build the intermediate outputs filled with the proper
292  // neutralElements. Such outputs are of the same shape with an extra dimension
293  // inserted at `insertSplitDimension`.
294  //
295  // Consider a minimal example where `k` is reduced:
296  // O(i, j) += I(i, j, k)
297  // Assume i=3, j=5, k=128, splitFactor=16 and insertSplitDimension=0.
298  // The compute is rewritten as:
299  // a. O_i(kk, i, j) += I(i, j, 16 * k + kk)
300  // b. O(i, j) += O_i(kk, i, j)
301  // The intermediate tensor O_i is of shape (128/16)x3x5 == 8x3x5.
302  Location loc = op->getLoc();
303  MLIRContext *context = op.getContext();
304  // For now assume outputs are 1-1 with reduction neutralElements.
305  // TODO: generalize when multi-reduction support is available.
306  SmallVector<Value> newOutputs;
307  newOutputs.reserve(op.getNumDpsInits());
308  SmallVector<Operation *> emptyOrAllocTensorOps;
310  fillOps.reserve(op.getNumDpsInits());
311  for (auto it : llvm::zip(op.getDpsInitsMutable(), neutralElements)) {
312  Value rankedTensor = std::get<0>(it).get();
313  auto t = cast<RankedTensorType>(rankedTensor.getType());
314  RankedTensorType newT = RankedTensorType::Builder(t).insertDim(
315  reductionDimSize / splitFactor, insertSplitDimension);
316  SmallVector<Value> dims =
317  tensor::createDynamicDimValues(b, loc, rankedTensor);
318  Value emptyOrAllocTensor;
319  if (useAlloc) {
320  emptyOrAllocTensor =
321  b.create<bufferization::AllocTensorOp>(loc, newT, dims);
322  } else {
323  emptyOrAllocTensor = b.create<tensor::EmptyOp>(loc, newT.getShape(),
324  t.getElementType(), dims);
325  }
326  Value constantOp = b.create<arith::ConstantOp>(loc, std::get<1>(it));
327  fillOps.push_back(
328  b.create<linalg::FillOp>(op->getLoc(), constantOp, emptyOrAllocTensor));
329  newOutputs.push_back(fillOps.back().getResult(0));
330  emptyOrAllocTensorOps.push_back(emptyOrAllocTensor.getDefiningOp());
331  }
332 
333  // Step 2. Reindex / expand indexing maps.
334  // Reindex existing input indexings: k -> k * splitFactor + k'.
335  SmallVector<AffineMap> newMaps;
336  newMaps.reserve(op->getNumOperands() + 1);
337  for (OpOperand *o : op.getDpsInputOperands())
338  newMaps.push_back(scaleReductionDim(op, *o, reductionDimPos, splitFactor));
339  // Provision a new indexing for the shape-only tensor.
340  auto nDims = op.getNumLoops() + 1;
341  auto redDim = getAffineDimExpr(reductionDimPos, context);
342  auto redDimP1 = getAffineDimExpr(reductionDimPos + 1, context);
343  newMaps.push_back(AffineMap::get(nDims, 0, {redDim, redDimP1}, context));
344  // Expand existing output indexings.
345  // TODO: a subset of these may not reduce along reducePos and should be
346  // reindexed: k -> k * splitFactor + k', when multi-reduction support is
347  // available.
348  for (OpOperand &o : op.getDpsInitsMutable())
349  newMaps.push_back(insertParallelDim(op, o, reductionDimPos,
350  reductionDimSize / splitFactor));
351 
352  // Step 3. Handle operands.
353  // Compute the new input tensors.
354  SmallVector<Value> newInputs = op.getDpsInputs();
355  // Add a single shape-only tensor to carry the dimensions without resorting to
356  // more complex inversions.
357  newInputs.push_back(b.create<tensor::EmptyOp>(
358  loc, ArrayRef<int64_t>{reductionDimSize / splitFactor, splitFactor},
359  b.getIntegerType(1)));
360  // Output tensors are already good to go.
361 
362  // Step 4. Create the new op matching the original op with an extra parallel
363  // dimension.
364  auto iteratorTypes = op.getIteratorTypesArray();
365  iteratorTypes.insert(iteratorTypes.begin() + reductionDimPos,
366  utils::IteratorType::parallel);
367  GenericOp genericOp =
368  b.create<GenericOp>(loc, ValueRange(newOutputs).getTypes(), newInputs,
369  newOutputs, newMaps, iteratorTypes);
370  b.inlineRegionBefore(op->getRegion(0), genericOp.getRegion(),
371  genericOp.getRegion().begin());
372  genericOp.getRegion().front().insertArgument(reductionDimPos,
373  b.getIntegerType(1), loc);
374 
375  // Step 5. Create new reduction ops that only reduce the newly added
376  // dimensions from the previous op.
377  // For now assume outputs are 1-1 with reduction ops.
378  // TODO: a subset of these may not reduce in the first place and do not
379  // require a new op, when multi-reduction support is available.
380  // TODO: all results can be handled in a single GenericOp, when
381  // multi-reduction support is available.
382  SmallVector<LinalgOp> results;
383  for (auto it :
384  llvm::zip(genericOp->getResults(), op.getDpsInits(), combinerOps)) {
385  Value reindexedOutput = std::get<0>(it);
386  Value originalOutput = std::get<1>(it);
387  auto originalOutputType = cast<RankedTensorType>(originalOutput.getType());
388  Operation *combinerOp = std::get<2>(it);
389 
390  AffineMap map = b.getMultiDimIdentityMap(originalOutputType.getRank() + 1);
391  SmallVector<AffineMap> indexingMaps = {
392  map, map.dropResult(insertSplitDimension)};
393  SmallVector<utils::IteratorType> reductionIteratorTypes(
394  originalOutputType.getRank() + 1, utils::IteratorType::parallel);
395  reductionIteratorTypes[insertSplitDimension] =
396  utils::IteratorType::reduction;
397 
398  // clang-format off
399  auto reductionOp = b.create<GenericOp>(
400  loc,
401  originalOutputType,
402  reindexedOutput,
403  originalOutput,
404  indexingMaps,
405  reductionIteratorTypes,
406  [combinerOp](OpBuilder &b, Location loc, ValueRange bbArgs) {
407  Operation *clonedReductionOp = b.clone(*combinerOp);
408  clonedReductionOp->setOperand(0, bbArgs[0]);
409  clonedReductionOp->setOperand(1, bbArgs[1]);
410  b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
411  });
412  // clang-format on
413 
414  results.push_back(reductionOp);
415  }
416 
417  // TODO: extend when multi-reduction support is available.
418  assert(fillOps.size() == results.size() && results.size() == 1);
419  b.replaceOp(op, results.front()->getResults());
420  return SplitReductionResult{emptyOrAllocTensorOps.front(), fillOps.front(),
421  cast<LinalgOp>(genericOp.getOperation()),
422  results.front()};
423 }
424 
425 namespace {
426 
427 struct LinalgSplitReduction : public OpInterfaceRewritePattern<LinalgOp> {
428  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
429  LinalgSplitReduction(MLIRContext *context,
430  ControlSplitReductionFn controlSplitReductionFn,
431  bool useAlloc = false, PatternBenefit benefit = 1)
432  : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
433  controlSplitReductionFn(std::move(controlSplitReductionFn)),
434  useAlloc(useAlloc) {}
435 
436  LogicalResult matchAndRewrite(LinalgOp op,
437  PatternRewriter &rewriter) const override {
438  return splitReduction(rewriter, op, controlSplitReductionFn, useAlloc);
439  }
440 
441 private:
442  ControlSplitReductionFn controlSplitReductionFn;
443  bool useAlloc;
444 };
445 
446 } // namespace
447 
450  const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc) {
451  patterns.add<LinalgSplitReduction>(patterns.getContext(),
452  controlSplitReductionFn, useAlloc);
453 }
static AffineMap scaleReductionDim(LinalgOp op, OpOperand &opOperand, unsigned reductionDimPos, int64_t reductionRatio)
Rewrite f(i, j, k, ...) into f(i, j, k * ratio + kk, ...) TODO: Additional pattern to rewrite f(i,...
static AffineMap insertParallelDim(LinalgOp op, OpOperand &opOperand, unsigned reductionDimPos, int64_t size)
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
unsigned getDimPosition(unsigned idx) const
Extracts the position of the dimensional expression at the given result, when the caller knows it is ...
Definition: AffineMap.cpp:415
static AffineMap getMultiDimIdentityMap(unsigned numDims, MLIRContext *context)
Returns an AffineMap with 'numDims' identity result dim exprs.
Definition: AffineMap.cpp:334
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
AffineMap shiftDims(unsigned shift, unsigned offset=0) const
Replace dims[offset ...
Definition: AffineMap.h:267
AffineMap insertResult(AffineExpr expr, unsigned pos) const
Returns a new AffineMap with the same number of dims and symbols and an extra result inserted at pos.
Definition: AffineMap.h:315
AffineMap dropResult(int64_t pos) const
Returns a new AffineMap with the same number of dims and symbols and one less result at pos,...
Definition: AffineMap.h:293
unsigned getNumDims() const
Definition: AffineMap.cpp:394
unsigned getNumResults() const
Definition: AffineMap.cpp:402
AffineMap replace(AffineExpr expr, AffineExpr replacement, unsigned numResultDims, unsigned numResultSyms) const
Sparse replace method.
Definition: AffineMap.cpp:515
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:556
Attributes are known-constant values of operations.
Definition: Attributes.h:25
AffineMap getMultiDimIdentityMap(unsigned rank)
Definition: Builders.cpp:427
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:111
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:404
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:357
This class helps build Operations.
Definition: Builders.h:216
Operation * clone(Operation &op, IRMapping &mapper)
Creates a deep copy of the specified operation, remapping any operands that use values outside of the...
Definition: Builders.cpp:588
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:407
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
This class represents an operand of an operation.
Definition: Value.h:267
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
void setOperand(unsigned idx, Value value)
Definition: Operation.h:351
Operation * clone(IRMapping &mapper, CloneOptions options=CloneOptions::all())
Create a deep copy of this operation, remapping any operands that use values outside of the operation...
Definition: Operation.cpp:717
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
This class represents the benefit of a pattern match in a unitless scheme that ranges from 0 (very li...
Definition: PatternMatch.h:34
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
This is a builder type that keeps local references to arguments.
Definition: BuiltinTypes.h:261
Builder & insertDim(int64_t val, unsigned pos)
Insert a val into shape @pos.
Definition: BuiltinTypes.h:295
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:724
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void inlineRegionBefore(Region &region, Region &parent, Region::iterator before)
Move the blocks that belong to "region" before the given position in another region "parent".
This class provides an abstraction over the various different ranges of value types.
Definition: TypeRange.h:36
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
std::optional< TypedAttr > getNeutralElement(Operation *op)
Return the identity numeric value associated with the given op.
Definition: ArithOps.cpp:2550
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
void populateSplitReductionPattern(RewritePatternSet &patterns, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Patterns to apply splitReduction below.
FailureOr< SplitReductionResult > splitReduction(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
std::function< SplitReductionOptions(LinalgOp op)> ControlSplitReductionFn
Function signature to control reduction splitting.
Definition: Transforms.h:442
FailureOr< SplitReductionResult > splitReductionByScaling(RewriterBase &b, LinalgOp op, const ControlSplitReductionFn &controlSplitReductionFn, bool useAlloc=false)
Scaling-based implementation of the split reduction transformation.
SmallVector< Value > createDynamicDimValues(OpBuilder &b, Location loc, Value rankedTensor)
Definition: Utils.cpp:63
Include the generated interface declarations.
Value matchReduction(ArrayRef< BlockArgument > iterCarriedArgs, unsigned redPos, SmallVectorImpl< Operation * > &combinerOps)
Utility to match a generic reduction given a list of iteration-carried arguments, iterCarriedArgs and...
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:617
OpInterfaceRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting a...
Definition: PatternMatch.h:373
Split Reduction options.
Definition: Transforms.h:427
Apply transformation to split the single linalg op reduction into a parallel and reduction dimension.
Definition: Transforms.h:1028