//===- DropUnitDims.cpp - Pass to drop use of unit-extent for broadcasting ===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements patterns/passes to remove usage of unit-extent
// dimensions to specify broadcasting in favor of a more canonical
// representation of the computation.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/Passes.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/Utils/Utils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Debug.h"

namespace mlir {
#define GEN_PASS_DEF_LINALGFOLDUNITEXTENTDIMSPASS
#include "mlir/Dialect/Linalg/Passes.h.inc"
} // namespace mlir

#define DEBUG_TYPE "linalg-drop-unit-dims"

using namespace mlir;
using namespace mlir::linalg;

namespace {
/// Pattern to move init operands to ins when all the loops are parallel and
/// the blockArgument corresponding to init is used in the region. This is a
/// fix-up for when unit reduction dimensions are all folded away. In this
/// context, the op becomes an elementwise generic op. E.g., it converts
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0 : tensor<1x?x1x1xf32>)
///    outs(%1 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %out: f32):
///    %3 = arith.addf %in, %out : f32
///    linalg.yield %3 : f32
///  } -> tensor<1x1xf32>
///
/// into
///
///  %0 = tensor.empty() : tensor<1x1xf32>
///  %1 = linalg.fill
///    ins(%cst : f32)
///    outs(%0 : tensor<1x1xf32>) -> tensor<1x1xf32>
///  %2 = tensor.empty() : tensor<1x1xf32>
///  %3 = linalg.generic {indexing_maps = [affine_map<(d0) -> (0, d0, 0, 0)>,
///                                        affine_map<(d0) -> (0, d0)>,
///                                        affine_map<(d0) -> (0, d0)>],
///                       iterator_types = ["parallel"]}
///    ins(%arg0, %1 : tensor<1x?x1x1xf32>, tensor<1x1xf32>)
///    outs(%2 : tensor<1x1xf32>) {
///  ^bb0(%in: f32, %in_0: f32, %out: f32):
///    %4 = arith.addf %in, %in_0 : f32
///    linalg.yield %4 : f32
///  } -> tensor<1x1xf32>
struct MoveInitOperandsToInput : public OpRewritePattern<GenericOp> {
  using OpRewritePattern<GenericOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    if (!genericOp.hasPureTensorSemantics())
      return failure();
    if (genericOp.getNumParallelLoops() != genericOp.getNumLoops())
      return failure();

    auto outputOperands = genericOp.getDpsInitsMutable();
    SetVector<OpOperand *> candidates;
    for (OpOperand &op : outputOperands) {
      if (genericOp.getMatchingBlockArgument(&op).use_empty())
        continue;
      candidates.insert(&op);
    }

    if (candidates.empty())
      return failure();

    // Compute the modified indexing maps.
    int64_t origNumInput = genericOp.getNumDpsInputs();
    SmallVector<Value> newInputOperands = genericOp.getDpsInputs();
    SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
    SmallVector<AffineMap> newIndexingMaps;
    newIndexingMaps.append(indexingMaps.begin(),
                           std::next(indexingMaps.begin(), origNumInput));
    for (OpOperand *op : candidates) {
      newInputOperands.push_back(op->get());
      newIndexingMaps.push_back(genericOp.getMatchingIndexingMap(op));
    }
    newIndexingMaps.append(std::next(indexingMaps.begin(), origNumInput),
                           indexingMaps.end());

    Location loc = genericOp.getLoc();
    SmallVector<Value> newOutputOperands =
        llvm::to_vector(genericOp.getDpsInits());
    for (OpOperand *op : candidates) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfterValue(op->get());
      auto elemType = cast<ShapedType>(op->get().getType()).getElementType();
      auto empty = tensor::EmptyOp::create(
          rewriter, loc, tensor::getMixedSizes(rewriter, loc, op->get()),
          elemType);

      unsigned start = genericOp.getDpsInits().getBeginOperandIndex();
      newOutputOperands[op->getOperandNumber() - start] = empty.getResult();
    }

    auto newOp = GenericOp::create(
        rewriter, loc, genericOp.getResultTypes(), newInputOperands,
        newOutputOperands, newIndexingMaps, genericOp.getIteratorTypesArray(),
        /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));

    OpBuilder::InsertionGuard guard(rewriter);
    Region &region = newOp.getRegion();
    Block *block = rewriter.createBlock(&region);
    IRMapping mapper;
    for (auto bbarg : genericOp.getRegionInputArgs())
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));

    for (OpOperand *op : candidates) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(op);
      mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (OpOperand &op : outputOperands) {
      BlockArgument bbarg = genericOp.getMatchingBlockArgument(&op);
      if (candidates.count(&op))
        block->addArgument(bbarg.getType(), loc);
      else
        mapper.map(bbarg, block->addArgument(bbarg.getType(), loc));
    }

    for (auto &op : genericOp.getBody()->getOperations()) {
      rewriter.clone(op, mapper);
    }
    rewriter.replaceOp(genericOp, newOp.getResults());

    return success();
  }
};
} // namespace

//===---------------------------------------------------------------------===//
// Drop loops that are unit-extents within Linalg operations.
//===---------------------------------------------------------------------===//

/// Implements a pass that canonicalizes the uses of unit-extent dimensions for
/// broadcasting. For example,
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (0, d1)>,
///   affine_map<(d0, d1) -> (d0, 0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.tensor_reshape %arg0 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<1x5xf32>
///   %1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
///        tensor<5xf32> into tensor<5x1xf32>
///   %2 = linalg.generic #trait %0, %1 {
///     ^bb0(%arg2: f32, %arg3: f32):
///       %3 = arith.addf %arg2, %arg3 : f32
///       linalg.yield %3 : f32
///   } : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
///   return %2 : tensor<5x5xf32>
/// }
/// ```
///
/// would canonicalize to
///
/// ```mlir
/// #accesses = [
///   affine_map<(d0, d1) -> (d1)>,
///   affine_map<(d0, d1) -> (d0)>,
///   affine_map<(d0, d1) -> (d0, d1)>
/// ]
///
/// #trait = {
///   indexing_maps = #accesses,
///   iterator_types = ["parallel", "parallel"],
///   library_call = "some_external_fn"
/// }
///
/// func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) ->
///     tensor<5x5xf32>
/// {
///   %0 = linalg.generic #trait %arg0, %arg1 {
///     ^bb0(%arg2: f32, %arg3: f32):
///       %3 = arith.addf %arg2, %arg3 : f32
///       linalg.yield %3 : f32
///   } : tensor<5xf32>, tensor<5xf32> -> tensor<5x5xf32>
///   return %0 : tensor<5x5xf32>
/// }
/// ```

/// Update the index accesses of linalg operations having index semantics.
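///
/// For illustration (example added here; the `unitDims` values are
/// hypothetical): with `unitDims = {0, 2}`, a body containing
///
/// ```mlir
/// %0 = linalg.index 0 : index
/// %1 = linalg.index 1 : index
/// %2 = linalg.index 2 : index
/// ```
///
/// is rewritten so that %0 and %2 become `arith.constant 0 : index`, and %1
/// becomes `linalg.index 0` since one dropped dimension precedes it.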
static void
replaceUnitDimIndexOps(GenericOp genericOp,
                       const llvm::SmallDenseSet<unsigned> &unitDims,
                       RewriterBase &rewriter) {
  for (IndexOp indexOp :
       llvm::make_early_inc_range(genericOp.getBody()->getOps<IndexOp>())) {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(indexOp);
    if (unitDims.count(indexOp.getDim()) != 0) {
      rewriter.replaceOpWithNewOp<arith::ConstantIndexOp>(indexOp, 0);
    } else {
      // Update the dimension of the index operation if needed.
      unsigned droppedDims = llvm::count_if(
          unitDims, [&](unsigned dim) { return dim < indexOp.getDim(); });
      if (droppedDims != 0)
        rewriter.replaceOpWithNewOp<IndexOp>(indexOp,
                                             indexOp.getDim() - droppedDims);
    }
  }
}

/// Expand the given `result` so that the type matches the type of `origDest`.
/// The `reassociation` is used when `rankReductionStrategy` is set to
/// `RankReductionStrategy::ReassociativeReshape`.
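///
/// As an illustrative sketch (types invented for this example): with the
/// `ExtractInsertSlice` strategy, expanding a `tensor<5xf32>` `result` into a
/// `tensor<1x5x1xf32>` `origDest` produces
///
/// ```mlir
/// %0 = tensor.insert_slice %result into %origDest[0, 0, 0] [1, 5, 1]
///     [1, 1, 1] : tensor<5xf32> into tensor<1x5x1xf32>
/// ```
///
/// while the `ReassociativeReshape` strategy produces a `tensor.expand_shape`
/// using the supplied `reassociation`.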
static Value
expandValue(RewriterBase &rewriter, Location loc, Value result, Value origDest,
            ArrayRef<ReassociationIndices> reassociation,
            ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  // There are no results for memref outputs.
  auto origResultType = cast<RankedTensorType>(origDest.getType());
  if (rankReductionStrategy ==
      ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    unsigned rank = origResultType.getRank();
    SmallVector<OpFoldResult> offsets(rank, rewriter.getIndexAttr(0));
    SmallVector<OpFoldResult> sizes =
        tensor::getMixedSizes(rewriter, loc, origDest);
    SmallVector<OpFoldResult> strides(rank, rewriter.getIndexAttr(1));
    return rewriter.createOrFold<tensor::InsertSliceOp>(
        loc, result, origDest, offsets, sizes, strides);
  }

  assert(rankReductionStrategy ==
             ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
         "unknown rank reduction strategy");
  return tensor::ExpandShapeOp::create(rewriter, loc, origResultType, result,
                                       reassociation)
      .getResult();
}

/// Collapse the given `operand` so that its type matches `targetShape`. The
/// `reassociation` is used when `rankReductionStrategy` is set to
/// `RankReductionStrategy::ReassociativeReshape`.
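///
/// As an illustrative sketch (types invented for this example): collapsing a
/// `tensor<1x5x1xf32>` operand to `targetShape = [5]` under the
/// `ReassociativeReshape` strategy produces
///
/// ```mlir
/// %0 = tensor.collapse_shape %operand [[0, 1, 2]]
///     : tensor<1x5x1xf32> into tensor<5xf32>
/// ```
///
/// while the `ExtractInsertSlice` strategy produces a rank-reducing
/// `tensor.extract_slice` (or `memref.subview` for memref operands).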
static Value collapseValue(
    RewriterBase &rewriter, Location loc, Value operand,
    ArrayRef<int64_t> targetShape, ArrayRef<ReassociationIndices> reassociation,
    ControlDropUnitDims::RankReductionStrategy rankReductionStrategy) {
  if (auto memrefType = dyn_cast<MemRefType>(operand.getType())) {
    if (rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          memref::SubViewOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(
        rankReductionStrategy ==
            ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
        "unknown rank reduction strategy");
    MemRefLayoutAttrInterface layout;
    auto targetType = MemRefType::get(targetShape, memrefType.getElementType(),
                                      layout, memrefType.getMemorySpace());
    return memref::CollapseShapeOp::create(rewriter, loc, targetType, operand,
                                           reassociation);
  }
  if (auto tensorType = dyn_cast<RankedTensorType>(operand.getType())) {
    if (rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      FailureOr<Value> rankReducingExtract =
          tensor::ExtractSliceOp::rankReduceIfNeeded(rewriter, loc, operand,
                                                     targetShape);
      assert(succeeded(rankReducingExtract) && "not a unit-extent collapse");
      return *rankReducingExtract;
    }

    assert(
        rankReductionStrategy ==
            ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape &&
        "unknown rank reduction strategy");
    auto targetType =
        RankedTensorType::get(targetShape, tensorType.getElementType());
    return tensor::CollapseShapeOp::create(rewriter, loc, targetType, operand,
                                           reassociation);
  }
  llvm_unreachable("unsupported operand type");
}

/// Compute the modified metadata for an operand of an operation whose unit
/// dims are being dropped. Return the new indexing map to use, the shape of
/// the operand in the replacement op, and the `reassociation` to use to go
/// from the original operand shape to the modified operand shape.
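///
/// For example (operand invented for illustration): an operand of type
/// `tensor<1x5x1xf32>` indexed by `(d0, d1, d2) -> (d0, d1, d2)`, where `d0`
/// and `d2` are dropped unit dims, yields `targetShape = [5]`, the single
/// reassociation group `[0, 1, 2]`, and the new indexing map `(d0) -> (d0)`.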
struct UnitExtentReplacementInfo {
  AffineMap indexMap;
  SmallVector<ReassociationIndices> reassociation;
  SmallVector<int64_t> targetShape;
};

static UnitExtentReplacementInfo dropUnitExtentFromOperandMetadata(
    MLIRContext *context, IndexingMapOpInterface op, OpOperand *opOperand,
    llvm::SmallDenseMap<unsigned, unsigned> &oldDimsToNewDimsMap,
    ArrayRef<AffineExpr> dimReplacements) {
  UnitExtentReplacementInfo info;
  ReassociationIndices reassociationGroup;
  SmallVector<AffineExpr> newIndexExprs;
  AffineMap indexingMap = op.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> operandShape = op.getStaticOperandShape(opOperand);
  ArrayRef<AffineExpr> exprs = indexingMap.getResults();

  auto isUnitDim = [&](unsigned dim) {
    if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[dim])) {
      unsigned oldPosition = dimExpr.getPosition();
      return !oldDimsToNewDimsMap.count(oldPosition) &&
             (operandShape[dim] == 1);
    }
    // Handle the other case where the shape is 1, and is accessed using a
    // constant 0.
    if (operandShape[dim] == 1) {
      auto constAffineExpr = dyn_cast<AffineConstantExpr>(exprs[dim]);
      return constAffineExpr && constAffineExpr.getValue() == 0;
    }
    return false;
  };

  unsigned dim = 0;
  while (dim < operandShape.size() && isUnitDim(dim))
    reassociationGroup.push_back(dim++);
  while (dim < operandShape.size()) {
    assert(!isUnitDim(dim) && "expected non unit-extent");
    reassociationGroup.push_back(dim);
    AffineExpr newExpr = exprs[dim].replaceDims(dimReplacements);
    newIndexExprs.push_back(newExpr);
    info.targetShape.push_back(operandShape[dim]);
    ++dim;
    // Fold all following dimensions that are unit-extent.
    while (dim < operandShape.size() && isUnitDim(dim)) {
      reassociationGroup.push_back(dim++);
    }
    info.reassociation.push_back(reassociationGroup);
    reassociationGroup.clear();
  }
  info.indexMap =
      AffineMap::get(oldDimsToNewDimsMap.size(), indexingMap.getNumSymbols(),
                     newIndexExprs, context);
  return info;
}

FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, IndexingMapOpInterface op,
                     const DroppedUnitDimsBuilder &droppedUnitDimsBuilder,
                     const ControlDropUnitDims &options) {
  auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op.getOperation());
  if (!dpsOp) {
    return rewriter.notifyMatchFailure(
        op, "op should implement DestinationStyleOpInterface");
  }

  SmallVector<AffineMap> indexingMaps = op.getIndexingMapsArray();
  if (indexingMaps.empty())
    return failure();

  // 1. Check if any of the iteration dimensions are unit-trip count. They will
  //    end up being unit-trip count if they are used to index into a unit-dim
  //    tensor/memref.
  AffineMap invertedMap =
      inversePermutation(concatAffineMaps(indexingMaps, rewriter.getContext()));
  if (!invertedMap) {
    return rewriter.notifyMatchFailure(op,
                                       "invalid indexing maps for operation");
  }

  SmallVector<int64_t> allShapesSizes;
  for (OpOperand &opOperand : op->getOpOperands())
    llvm::append_range(allShapesSizes, op.getStaticOperandShape(&opOperand));

  // 1a. Get the allowed list of dimensions to drop from the `options`.
  SmallVector<unsigned> allowedUnitDims = options.controlFn(op);
  if (allowedUnitDims.empty()) {
    return rewriter.notifyMatchFailure(
        op, "control function returns no allowed unit dims to prune");
  }
  llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                               allowedUnitDims.end());
  llvm::SmallDenseSet<unsigned> unitDims;
  for (const auto &expr : enumerate(invertedMap.getResults())) {
    if (AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(expr.value())) {
      if (allShapesSizes[dimExpr.getPosition()] == 1 &&
          unitDimsFilter.count(expr.index()))
        unitDims.insert(expr.index());
    }
  }

  // 2. Compute the new loops of the modified op by dropping the one-trip
  //    count loops.
  llvm::SmallDenseMap<unsigned, unsigned> oldDimToNewDimMap;
  SmallVector<AffineExpr> dimReplacements;
  unsigned newDims = 0;
  for (auto index : llvm::seq<int64_t>(op.getStaticLoopRanges().size())) {
    if (unitDims.count(index)) {
      dimReplacements.push_back(
          getAffineConstantExpr(0, rewriter.getContext()));
    } else {
      oldDimToNewDimMap[index] = newDims;
      dimReplacements.push_back(
          getAffineDimExpr(newDims, rewriter.getContext()));
      newDims++;
    }
  }
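  // For example (an illustrative trace, not from the upstream source): with
  // three loops and unitDims = {1}, the loop above produces
  // dimReplacements = [d0, 0, d1] and oldDimToNewDimMap = {0 -> 0, 2 -> 1}.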

  // 3. For each of the operands, find the
  //    - modified affine map to use.
  //    - shape of the operands after the unit-dims are dropped.
  //    - the reassociation indices used to convert from the original
  //      operand type to the modified operand type (needed only when using
  //      reshapes for the rank reduction strategy).
  // Note that the indexing maps might need changing even if there are no
  // unit dimensions that are dropped, to handle cases where `0` is used to
  // access a unit-extent tensor. Consider moving this out of this specific
  // transformation as a stand-alone transformation. Kept here right now due
  // to legacy.
  SmallVector<AffineMap> newIndexingMaps;
  SmallVector<SmallVector<ReassociationIndices>> reassociations;
  SmallVector<SmallVector<int64_t>> targetShapes;
  SmallVector<bool> collapsed;
  auto hasCollapsibleType = [](OpOperand &operand) {
    Type operandType = operand.get().getType();
    if (auto memrefOperandType = dyn_cast_or_null<MemRefType>(operandType)) {
      return memrefOperandType.getLayout().isIdentity();
    }
    if (auto tensorOperandType = dyn_cast<RankedTensorType>(operandType)) {
      return tensorOperandType.getEncoding() == nullptr;
    }
    return false;
  };
  for (OpOperand &opOperand : op->getOpOperands()) {
    auto indexingMap = op.getMatchingIndexingMap(&opOperand);
    SmallVector<int64_t> shape = op.getStaticOperandShape(&opOperand);
    if (!hasCollapsibleType(opOperand)) {
      AffineMap newIndexingMap = indexingMap.replaceDimsAndSymbols(
          dimReplacements, ArrayRef<AffineExpr>{}, oldDimToNewDimMap.size(), 0);
      newIndexingMaps.push_back(newIndexingMap);
      targetShapes.push_back(llvm::to_vector(shape));
      collapsed.push_back(false);
      reassociations.push_back({});
      continue;
    }
    auto replacementInfo =
        dropUnitExtentFromOperandMetadata(rewriter.getContext(), op, &opOperand,
                                          oldDimToNewDimMap, dimReplacements);
    reassociations.push_back(replacementInfo.reassociation);
    newIndexingMaps.push_back(replacementInfo.indexMap);
    targetShapes.push_back(replacementInfo.targetShape);
    collapsed.push_back(!(replacementInfo.indexMap.getNumResults() ==
                          indexingMap.getNumResults()));
  }

  // Abort if the indexing maps of the result operation are not invertible
  // (i.e. not legal) or if no dimension was reduced.
  if (newIndexingMaps == indexingMaps ||
      !inversePermutation(
          concatAffineMaps(newIndexingMaps, rewriter.getContext())))
    return failure();

  Location loc = op.getLoc();
  // 4. For each of the operands, collapse the operand to convert
  //    from the original shape to the shape in the modified operation, if
  //    needed, either through use of reshapes or rank-reducing slices as
  //    specified in `options`.
  SmallVector<Value> newOperands;
  for (OpOperand &opOperand : op->getOpOperands()) {
    int64_t idx = opOperand.getOperandNumber();
    if (!collapsed[idx]) {
      newOperands.push_back(opOperand.get());
      continue;
    }
    newOperands.push_back(collapseValue(rewriter, loc, opOperand.get(),
                                        targetShapes[idx], reassociations[idx],
                                        options.rankReductionStrategy));
  }

  IndexingMapOpInterface replacementOp = droppedUnitDimsBuilder(
      loc, rewriter, op, newOperands, newIndexingMaps, unitDims);

  // 6. If any result type changes, insert a reshape/slice to convert from the
  //    original type to the new type.
  SmallVector<Value> resultReplacements;
  for (auto [index, result] : llvm::enumerate(replacementOp->getResults())) {
    unsigned opOperandIndex = index + dpsOp.getNumDpsInputs();
    Value origDest = dpsOp.getDpsInitOperand(index)->get();
    if (!collapsed[opOperandIndex]) {
      resultReplacements.push_back(result);
      continue;
    }
    Value expandedValue = expandValue(rewriter, loc, result, origDest,
                                      reassociations[opOperandIndex],
                                      options.rankReductionStrategy);
    resultReplacements.push_back(expandedValue);
  }

  return DropUnitDimsResult{replacementOp, resultReplacements};
}

FailureOr<DropUnitDimsResult>
linalg::dropUnitDims(RewriterBase &rewriter, GenericOp genericOp,
                     const ControlDropUnitDims &options) {

  DroppedUnitDimsBuilder build =
      [](Location loc, OpBuilder &b, IndexingMapOpInterface op,
         ArrayRef<Value> newOperands, ArrayRef<AffineMap> newIndexingMaps,
         const llvm::SmallDenseSet<unsigned> &droppedDims)
      -> IndexingMapOpInterface {
    auto genericOp = cast<GenericOp>(op);
    // Compute the iterator types of the modified op by dropping the one-trip
    // count loops.
    SmallVector<utils::IteratorType> newIteratorTypes;
    for (auto [index, attr] :
         llvm::enumerate(genericOp.getIteratorTypesArray())) {
      if (!droppedDims.count(index))
        newIteratorTypes.push_back(attr);
    }

    // Create the `linalg.generic` operation with the new operands,
    // indexing maps, iterator types and result types.
    ArrayRef<Value> newInputs =
        ArrayRef<Value>(newOperands).take_front(genericOp.getNumDpsInputs());
    ArrayRef<Value> newOutputs =
        ArrayRef<Value>(newOperands).take_back(genericOp.getNumDpsInits());
    SmallVector<Type> resultTypes;
    resultTypes.reserve(genericOp.getNumResults());
    for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
      resultTypes.push_back(newOutputs[i].getType());
    GenericOp replacementOp =
        GenericOp::create(b, loc, resultTypes, newInputs, newOutputs,
                          newIndexingMaps, newIteratorTypes);
    b.cloneRegionBefore(genericOp.getRegion(), replacementOp.getRegion(),
                        replacementOp.getRegion().begin());
    // 5a. Replace `linalg.index` operations that refer to the dropped unit
    //     dimensions.
    IRRewriter rewriter(b);
    replaceUnitDimIndexOps(replacementOp, droppedDims, rewriter);

    return replacementOp;
  };

  return dropUnitDims(rewriter, genericOp, build, options);
}

namespace {
struct DropUnitDims : public OpRewritePattern<GenericOp> {
  DropUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
               PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    FailureOr<DropUnitDimsResult> result =
        dropUnitDims(rewriter, genericOp, options);
    if (failed(result)) {
      return failure();
    }
    rewriter.replaceOp(genericOp, result->replacements);
    return success();
  }

private:
  ControlDropUnitDims options;
};
} // namespace

//===---------------------------------------------------------------------===//
// Drop dimensions that are unit-extents within tensor operations.
//===---------------------------------------------------------------------===//

namespace {
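/// Drop unit-extent dimensions from a `tensor.pad` when the padding along
/// those dimensions is statically zero. As an illustrative sketch (shapes
/// invented for this example, assuming the reassociative-reshape strategy):
///
/// ```mlir
/// %0 = tensor.pad %arg0 low[0, 1, 0] high[0, 1, 0] { ... }
///     : tensor<1x4x1xf32> to tensor<1x6x1xf32>
/// ```
///
/// is rewritten to pad the collapsed source and expand the result back:
///
/// ```mlir
/// %c = tensor.collapse_shape %arg0 [[0, 1, 2]]
///     : tensor<1x4x1xf32> into tensor<4xf32>
/// %p = tensor.pad %c low[1] high[1] { ... } : tensor<4xf32> to tensor<6xf32>
/// %e = tensor.expand_shape %p [[0, 1, 2]] output_shape [1, 6, 1]
///     : tensor<6xf32> into tensor<1x6x1xf32>
/// ```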
struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
  DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
                  PatternBenefit benefit = 1)
      : OpRewritePattern(context, benefit), options(std::move(options)) {}

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    // 1a. Get the allowed list of dimensions to drop from the `options`.
    SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
    if (allowedUnitDims.empty()) {
      return rewriter.notifyMatchFailure(
          padOp, "control function returns no allowed unit dims to prune");
    }

    if (padOp.getSourceType().getEncoding()) {
      return rewriter.notifyMatchFailure(
          padOp, "cannot collapse dims of tensor with encoding");
    }

    // Fail for non-constant padding values. The body of the pad could
    // depend on the padding indices and/or properties of the padded
    // tensor, so for now we fail.
    // TODO: Support non-constant padding values.
    Value paddingVal = padOp.getConstantPaddingValue();
    if (!paddingVal) {
      return rewriter.notifyMatchFailure(
          padOp, "unimplemented: non-constant padding value");
    }

    ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
    ArrayRef<int64_t> resultShape = padOp.getResultType().getShape();
    int64_t padRank = sourceShape.size();

    auto isStaticZero = [](OpFoldResult f) {
      return getConstantIntValue(f) == 0;
    };

    llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
                                                 allowedUnitDims.end());
    llvm::SmallDenseSet<unsigned> unitDims;
    SmallVector<int64_t> newShape;
    SmallVector<int64_t> newResultShape;
    SmallVector<OpFoldResult> newLowPad;
    SmallVector<OpFoldResult> newHighPad;
    for (const auto [dim, size, outSize, low, high] : zip_equal(
             llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
             resultShape, padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
      if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
          isStaticZero(high)) {
        unitDims.insert(dim);
      } else {
        newShape.push_back(size);
        newResultShape.push_back(outSize);
        newLowPad.push_back(low);
        newHighPad.push_back(high);
      }
    }

    if (unitDims.empty()) {
      return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
    }

    ReassociationIndices reassociationGroup;
    SmallVector<ReassociationIndices> reassociationMap;
    int64_t dim = 0;
    while (dim < padRank && unitDims.contains(dim))
      reassociationGroup.push_back(dim++);
    while (dim < padRank) {
      assert(!unitDims.contains(dim) && "expected non unit-extent");
      reassociationGroup.push_back(dim);
      dim++;
      // Fold all following dimensions that are unit-extent.
      while (dim < padRank && unitDims.contains(dim))
        reassociationGroup.push_back(dim++);
      reassociationMap.push_back(reassociationGroup);
      reassociationGroup.clear();
    }

    Value collapsedSource =
        collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
                      reassociationMap, options.rankReductionStrategy);

    auto newResultType = RankedTensorType::get(
        newResultShape, padOp.getResultType().getElementType());
    auto newPadOp = tensor::PadOp::create(
        rewriter, padOp.getLoc(), /*result=*/newResultType, collapsedSource,
        newLowPad, newHighPad, paddingVal, padOp.getNofold());

    Value dest = padOp.getResult();
    if (options.rankReductionStrategy ==
        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
      SmallVector<OpFoldResult> expandedSizes;
      int64_t numUnitDims = 0;
      for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
        if (unitDims.contains(dim)) {
          expandedSizes.push_back(rewriter.getIndexAttr(1));
          numUnitDims++;
          continue;
        }
        expandedSizes.push_back(tensor::getMixedSize(
            rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
      }
      dest = tensor::EmptyOp::create(rewriter, padOp.getLoc(), expandedSizes,
                                     padOp.getResultType().getElementType());
    }

    Value expandedValue =
        expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
                    reassociationMap, options.rankReductionStrategy);
    rewriter.replaceOp(padOp, expandedValue);
    return success();
  }

private:
  ControlDropUnitDims options;
};
} // namespace

namespace {
/// Convert `extract_slice` operations to rank-reduced versions.
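/// For example (an illustrative sketch with invented shapes):
///
/// ```mlir
/// %0 = tensor.extract_slice %src[0, 0] [1, 5] [1, 1]
///     : tensor<8x8xf32> to tensor<1x5xf32>
/// ```
///
/// is rewritten to a rank-reduced slice yielding `tensor<5xf32>` followed by a
/// `tensor.expand_shape` back to `tensor<1x5xf32>`.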
struct RankReducedExtractSliceOp
    : public OpRewritePattern<tensor::ExtractSliceOp> {
  using OpRewritePattern<tensor::ExtractSliceOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractSliceOp sliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType resultType = sliceOp.getType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : resultType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(resultType.getRank()))
      return failure();

    SmallVector<OpFoldResult> offsets = sliceOp.getMixedOffsets();
    SmallVector<OpFoldResult> strides = sliceOp.getMixedStrides();
    SmallVector<OpFoldResult> sizes = sliceOp.getMixedSizes();
    auto rankReducedType = cast<RankedTensorType>(
        tensor::ExtractSliceOp::inferCanonicalRankReducedResultType(
            reassociation->size(), sliceOp.getSourceType(), offsets, sizes,
            strides));

    Location loc = sliceOp.getLoc();
    Value newSlice = tensor::ExtractSliceOp::create(
        rewriter, loc, rankReducedType, sliceOp.getSource(), offsets, sizes,
        strides);
    rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
        sliceOp, resultType, newSlice, *reassociation);
    return success();
  }
};

/// Convert `insert_slice` operations to rank-reduced versions.
/// This pattern works with both InsertSliceOp and ParallelInsertSliceOp.
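/// For example (an illustrative sketch with invented shapes): inserting a
/// `tensor<1x5xf32>` source is rewritten to first collapse the source,
///
/// ```mlir
/// %0 = tensor.collapse_shape %src [[0, 1]]
///     : tensor<1x5xf32> into tensor<5xf32>
/// ```
///
/// and then insert the rank-reduced value, keeping the original offsets,
/// sizes, and strides (which refer to the unchanged destination).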
template <typename InsertOpTy>
struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
  using OpRewritePattern<InsertOpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(InsertOpTy insertSliceOp,
                                PatternRewriter &rewriter) const override {
    RankedTensorType sourceType = insertSliceOp.getSourceType();
    SmallVector<OpFoldResult> targetShape;
    for (auto size : sourceType.getShape())
      targetShape.push_back(rewriter.getIndexAttr(size));
    auto reassociation = getReassociationMapForFoldingUnitDims(targetShape);
    if (!reassociation ||
        reassociation->size() == static_cast<size_t>(sourceType.getRank()))
      return failure();

    Location loc = insertSliceOp.getLoc();
    tensor::CollapseShapeOp reshapedSource;
    {
      OpBuilder::InsertionGuard g(rewriter);
      // The only difference between InsertSliceOp and ParallelInsertSliceOp
      // is that the insertion point is just before the ParallelCombiningOp in
      // the parallel case.
      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
      reshapedSource = tensor::CollapseShapeOp::create(
          rewriter, loc, insertSliceOp.getSource(), *reassociation);
    }
    rewriter.replaceOpWithNewOp<InsertOpTy>(
        insertSliceOp, reshapedSource, insertSliceOp.getDest(),
        insertSliceOp.getMixedOffsets(), insertSliceOp.getMixedSizes(),
        insertSliceOp.getMixedStrides());
    return success();
  }
};
} // namespace

/// Patterns that are used to canonicalize the use of unit-extent dims for
/// broadcasting.
static void
populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
                                              ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  patterns.add<DropPadUnitDims>(context, options);
  // TODO: Patterns unrelated to unit dim folding should be factored out.
  patterns.add<RankReducedExtractSliceOp,
               RankReducedInsertSliceOp<tensor::InsertSliceOp>,
               RankReducedInsertSliceOp<tensor::ParallelInsertSliceOp>>(
      context);
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::CollapseShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::ExpandShapeOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

static void
populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
                                            ControlDropUnitDims &options) {
  auto *context = patterns.getContext();
  patterns.add<DropUnitDims>(context, options);
  patterns.add<DropPadUnitDims>(context, options);
  // TODO: Patterns unrelated to unit dim folding should be factored out.
  linalg::FillOp::getCanonicalizationPatterns(patterns, context);
  tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
  tensor::populateFoldTensorEmptyPatterns(patterns);
  memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns);
  memref::populateResolveShapedTypeResultDimsPatterns(patterns);
}

void mlir::linalg::populateFoldUnitExtentDimsPatterns(
    RewritePatternSet &patterns, linalg::ControlDropUnitDims &options) {
  if (options.rankReductionStrategy ==
      linalg::ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
    populateFoldUnitExtentDimsViaSlicesPatterns(patterns, options);
  } else if (options.rankReductionStrategy ==
             linalg::ControlDropUnitDims::RankReductionStrategy::
                 ReassociativeReshape) {
    populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
  }
}

void mlir::linalg::populateMoveInitOperandsToInputPattern(
    RewritePatternSet &patterns) {
  patterns.add<MoveInitOperandsToInput>(patterns.getContext());
}

namespace {
/// Pass that removes unit-extent dims within generic ops.
struct LinalgFoldUnitExtentDimsPass
    : public impl::LinalgFoldUnitExtentDimsPassBase<
          LinalgFoldUnitExtentDimsPass> {
  using impl::LinalgFoldUnitExtentDimsPassBase<
      LinalgFoldUnitExtentDimsPass>::LinalgFoldUnitExtentDimsPassBase;
  void runOnOperation() override {
    Operation *op = getOperation();
    MLIRContext *context = op->getContext();
    RewritePatternSet patterns(context);
    ControlDropUnitDims options;
    if (useRankReducingSlices) {
      options.rankReductionStrategy = linalg::ControlDropUnitDims::
          RankReductionStrategy::ExtractInsertSlice;
    }
    linalg::populateFoldUnitExtentDimsPatterns(patterns, options);
    populateMoveInitOperandsToInputPattern(patterns);
    (void)applyPatternsGreedily(op, std::move(patterns));
  }
};

} // namespace

namespace {

/// Returns reassociation indices for collapsing/expanding a
/// tensor of rank `rank` at position `pos`.
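///
/// For example (values added for illustration): `rank = 4, pos = 1` yields
/// `[[0], [1, 2], [3]]`, while `rank = 4, pos = 3` yields
/// `[[0], [1], [2, 3]]`.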
static SmallVector<ReassociationIndices>
getReassociationForReshapeAtDim(int64_t rank, int64_t pos) {
  SmallVector<ReassociationIndices> reassociation(rank - 1, {0, 1});
  bool lastDim = pos == rank - 1;
  if (rank > 2) {
    for (int64_t i = 0; i < rank - 1; i++) {
      if (i == pos || (lastDim && i == pos - 1))
        reassociation[i] = ReassociationIndices{i, i + 1};
      else if (i < pos)
        reassociation[i] = ReassociationIndices{i};
      else
        reassociation[i] = ReassociationIndices{i + 1};
    }
  }
  return reassociation;
}

/// Returns a collapsed `val` where the collapsing occurs at dim `pos`.
/// If `pos < 0`, then don't collapse.
static Value collapseSingletonDimAt(PatternRewriter &rewriter, Value val,
                                    int64_t pos) {
  if (pos < 0)
    return val;
  auto valType = cast<ShapedType>(val.getType());
  SmallVector<int64_t> collapsedShape(valType.getShape());
  collapsedShape.erase(collapsedShape.begin() + pos);
  return collapseValue(
      rewriter, val.getLoc(), val, collapsedShape,
      getReassociationForReshapeAtDim(valType.getRank(), pos),
      ControlDropUnitDims::RankReductionStrategy::ReassociativeReshape);
}

/// Base class for all rank reduction patterns for contraction ops
/// with unit dimensions. All patterns should convert one named op
/// to another named op. Intended to reduce only one iteration space dim
/// at a time.
/// Reducing multiple dims will happen with recursive application of
/// pattern rewrites.
template <typename FromOpTy, typename ToOpTy>
struct RankReduceContractionOps : OpRewritePattern<FromOpTy> {
  using OpRewritePattern<FromOpTy>::OpRewritePattern;

  /// Collapse all collapsible operands.
  SmallVector<Value>
  collapseOperands(PatternRewriter &rewriter, ArrayRef<Value> operands,
                   ArrayRef<int64_t> operandCollapseDims) const {
    assert(operandCollapseDims.size() == 3 && operands.size() == 3 &&
           "expected 3 operands and dims");
    return llvm::map_to_vector(
        llvm::zip(operands, operandCollapseDims), [&](auto pair) {
          return collapseSingletonDimAt(rewriter, std::get<0>(pair),
                                        std::get<1>(pair));
        });
  }

  /// Expand result tensor.
  Value expandResult(PatternRewriter &rewriter, Value result,
                     RankedTensorType expandedType, int64_t dim) const {
    return tensor::ExpandShapeOp::create(
        rewriter, result.getLoc(), expandedType, result,
        getReassociationForReshapeAtDim(expandedType.getRank(), dim));
  }

  LogicalResult matchAndRewrite(FromOpTy contractionOp,
                                PatternRewriter &rewriter) const override {
    if (contractionOp.hasUserDefinedMaps()) {
      return rewriter.notifyMatchFailure(
          contractionOp, "ops with user-defined maps are not supported");
    }

    auto loc = contractionOp.getLoc();
    auto inputs = contractionOp.getDpsInputs();
    auto inits = contractionOp.getDpsInits();
    if (inputs.size() != 2 || inits.size() != 1)
      return rewriter.notifyMatchFailure(contractionOp,
                                         "expected 2 inputs and 1 init");
    auto lhs = inputs[0];
    auto rhs = inputs[1];
    auto init = inits[0];
    SmallVector<Value> operands{lhs, rhs, init};

    SmallVector<int64_t> operandUnitDims;
    if (failed(getOperandUnitDims(contractionOp, operandUnitDims)))
      return rewriter.notifyMatchFailure(contractionOp,
                                         "no reducible dims found");

    SmallVector<Value> collapsedOperands =
        collapseOperands(rewriter, operands, operandUnitDims);
    Value collapsedLhs = collapsedOperands[0];
    Value collapsedRhs = collapsedOperands[1];
    Value collapsedInit = collapsedOperands[2];
    SmallVector<Type, 1> collapsedResultTy;
    if (isa<RankedTensorType>(collapsedInit.getType()))
      collapsedResultTy.push_back(collapsedInit.getType());
    auto collapsedOp = ToOpTy::create(rewriter, loc, collapsedResultTy,
                                      ValueRange{collapsedLhs, collapsedRhs},
                                      ValueRange{collapsedInit});
    for (auto attr : contractionOp->getAttrs()) {
      if (attr.getName() == LinalgDialect::kMemoizedIndexingMapsAttrName ||
          attr.getName() == "indexing_maps")
        continue;
      collapsedOp->setAttr(attr.getName(), attr.getValue());
    }

    auto results = contractionOp.getResults();
    assert(results.size() < 2 && "expected at most one result");
    if (results.empty()) {
      rewriter.replaceOp(contractionOp, collapsedOp);
    } else {
      rewriter.replaceOp(
          contractionOp,
          expandResult(rewriter, collapsedOp.getResultTensors()[0],
                       cast<RankedTensorType>(results[0].getType()),
                       operandUnitDims[2]));
    }

    return success();
  }

  /// Populate `operandUnitDims` with 3 indices indicating the unit dim
  /// for each operand that should be collapsed in this pattern. If an
  /// operand shouldn't be collapsed, the index should be negative.
  virtual LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const = 0;
};

/// Patterns for unbatching batched contraction ops.
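///
/// For example (an illustrative sketch with invented shapes): a
/// `linalg.batch_matmul` on `tensor<1x4x8xf32>` and `tensor<1x8x16xf32>` with
/// unit batch size is rewritten to a `linalg.matmul` on the collapsed
/// `tensor<4x8xf32>` and `tensor<8x16xf32>` operands, with the result
/// expanded back to `tensor<1x4x16xf32>` afterwards.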
template <typename FromOpTy, typename ToOpTy>
struct RankReduceToUnBatched : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Look for unit batch dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    if (contractionDims.batch.size() != 1)
      return failure();
    auto batchDim = contractionDims.batch[0];
    SmallVector<std::pair<Value, unsigned>, 3> bOperands;
    op.mapIterationSpaceDimToAllOperandDims(batchDim, bOperands);
    if (bOperands.size() != 3 || llvm::any_of(bOperands, [](auto pair) {
          return cast<ShapedType>(std::get<0>(pair).getType())
                     .getShape()[std::get<1>(pair)] != 1;
        })) {
      LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
      return failure();
    }

    operandUnitDims = SmallVector<int64_t>{std::get<1>(bOperands[0]),
                                           std::get<1>(bOperands[1]),
                                           std::get<1>(bOperands[2])};
    return success();
  }
};

/// Patterns for reducing non-batch dimensions.
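///
/// For example (an illustrative sketch with invented shapes): a
/// `linalg.matmul` with a unit M dimension, `tensor<1x8xf32>` times
/// `tensor<8x16xf32>` into `tensor<1x16xf32>`, is rewritten to a
/// `linalg.vecmat` on `tensor<8xf32>` and `tensor<8x16xf32>` into
/// `tensor<16xf32>`; a unit N dimension similarly maps matmul to matvec.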
template <typename FromOpTy, typename ToOpTy>
struct RankReduceMatmul : RankReduceContractionOps<FromOpTy, ToOpTy> {
  using RankReduceContractionOps<FromOpTy, ToOpTy>::RankReduceContractionOps;

  /// Helper for determining whether the lhs/init or rhs/init are reduced.
  static bool constexpr reduceLeft =
      (std::is_same_v<FromOpTy, BatchMatmulOp> &&
       std::is_same_v<ToOpTy, BatchVecmatOp>) ||
      (std::is_same_v<FromOpTy, MatmulOp> &&
       std::is_same_v<ToOpTy, VecmatOp>) ||
      (std::is_same_v<FromOpTy, MatvecOp> && std::is_same_v<ToOpTy, DotOp>);

  /// Look for non-batch spatial dims to collapse.
  LogicalResult
  getOperandUnitDims(LinalgOp op,
                     SmallVectorImpl<int64_t> &operandUnitDims) const override {
    FailureOr<ContractionDimensions> maybeContractionDims =
        inferContractionDims(op);
    if (failed(maybeContractionDims)) {
      LLVM_DEBUG(llvm::dbgs() << "could not infer contraction dims");
      return failure();
    }
    ContractionDimensions contractionDims = maybeContractionDims.value();

    if constexpr (reduceLeft) {
      auto m = contractionDims.m[0];
      SmallVector<std::pair<Value, unsigned>, 2> mOperands;
      op.mapIterationSpaceDimToAllOperandDims(m, mOperands);
      if (mOperands.size() != 2)
        return failure();
      if (llvm::all_of(mOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{std::get<1>(mOperands[0]), -1,
                                               std::get<1>(mOperands[1])};
        return success();
      }
    } else {
      auto n = contractionDims.n[0];
      SmallVector<std::pair<Value, unsigned>, 2> nOperands;
      op.mapIterationSpaceDimToAllOperandDims(n, nOperands);
      if (nOperands.size() != 2)
        return failure();
      if (llvm::all_of(nOperands, [](auto pair) {
            return cast<ShapedType>(std::get<0>(pair).getType())
                       .getShape()[std::get<1>(pair)] == 1;
          })) {
        operandUnitDims = SmallVector<int64_t>{-1, std::get<1>(nOperands[0]),
                                               std::get<1>(nOperands[1])};
        return success();
      }
    }
    LLVM_DEBUG(llvm::dbgs() << "specified unit dims not found");
    return failure();
  }
};

} // namespace

void linalg::populateContractionOpRankReducingPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *context = patterns.getContext();
  // Unbatching patterns for unit batch size.
  patterns.add<RankReduceToUnBatched<BatchMatmulOp, MatmulOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchMatvecOp, MatvecOp>>(context);
  patterns.add<RankReduceToUnBatched<BatchVecmatOp, VecmatOp>>(context);

  // Non-batch rank-1 reducing patterns.
  patterns.add<RankReduceMatmul<MatmulOp, VecmatOp>>(context);
  patterns.add<RankReduceMatmul<MatmulOp, MatvecOp>>(context);
  // Batch rank-1 reducing patterns.
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchVecmatOp>>(context);
  patterns.add<RankReduceMatmul<BatchMatmulOp, BatchMatvecOp>>(context);

  // Non-batch rank-0 reducing patterns.
  patterns.add<RankReduceMatmul<MatvecOp, DotOp>>(context);
  patterns.add<RankReduceMatmul<VecmatOp, DotOp>>(context);
}
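
// Example driver (a minimal sketch, not part of the upstream file; `funcOp`
// stands for a hypothetical operation to rewrite within a pass):
//
//   RewritePatternSet patterns(funcOp->getContext());
//   linalg::populateContractionOpRankReducingPatterns(patterns);
//   if (failed(applyPatternsGreedily(funcOp, std::move(patterns))))
//     signalPassFailure();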