1//===- DataLayoutPropagation.cpp -----------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "mlir/IR/Dominance.h"
18#include "llvm/ADT/SetOperations.h"
19#include "llvm/ADT/SetVector.h"
20#include "llvm/ADT/TypeSwitch.h"
21#include "llvm/Support/Debug.h"
22#include <optional>
23
24namespace mlir {
25#define GEN_PASS_DEF_LINALGDATALAYOUTPROPAGATION
26#include "mlir/Dialect/Linalg/Passes.h.inc"
27} // namespace mlir
28
29using namespace mlir;
30using namespace mlir::linalg;
31
32#define DEBUG_TYPE "linalg-data-layout-propagation"
33
34namespace {
35
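/// Returns true if the generic op's body reads data in an index-dependent way,
/// i.e. it contains a tensor.extract or a linalg.index op; such gather-like
/// semantics currently block data-layout propagation.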
36static bool hasGatherSemantics(linalg::GenericOp genericOp) {
37 for (Operation &op : genericOp.getBody()->getOperations())
38 if (isa<tensor::ExtractOp, linalg::IndexOp>(op))
39 return true;
40 return false;
41}
42
43// This struct describes how the packing information of a pack/unpack op maps
44// onto the iteration domain of Linalg ops.
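// For example, for a generic op over the 2-D iteration domain (d0, d1) and a
// pack with inner_dims_pos = [0, 1] and inner_tiles = [8, 2] on a result with
// an identity indexing map, this is populated as:
//   tiledDimsPos            = [0, 1]
//   domainDimAndTileMapping = {0 -> 8, 1 -> 2}
//   tileToPointMapping      = {0 -> 2, 1 -> 3}   // point loops d2 and d3
//   outerDimsOnDomainPerm   = []                 // no outer_dims_perm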
45struct PackInfo {
46 int64_t getNumTiledLoops() const { return tileToPointMapping.size(); };
47 // The tiled dims in the iteration domain, following the order in the pack op.
48 SmallVector<int64_t> tiledDimsPos;
49 // The mapping from a dimension of the iteration domain to its tile size.
50 llvm::DenseMap<int64_t, OpFoldResult> domainDimAndTileMapping;
51 // The mapping from a dimension of the iteration domain to the corresponding
52 // inner tiling (point-loop) dimension in the packed iteration domain.
53 llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
54 // The permutation of the outer dims (on the iteration domain).
55 SmallVector<int64_t> outerDimsOnDomainPerm;
56};
57
58template <typename OpTy>
59static FailureOr<PackInfo>
60getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
61 OpTy packOrUnPackOp) {
62 static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
63 "applies to only pack or unpack operations");
64 LLVM_DEBUG(
65 { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });
66
67 AffineMap indexingMap = genericOp.getMatchingIndexingMap(opOperand);
68 SmallVector<AffineMap> indexingMaps = genericOp.getIndexingMapsArray();
69 SmallVector<utils::IteratorType> iterators =
70     genericOp.getIteratorTypesArray();
71
72 PackInfo packInfo;
73 int64_t origNumDims = indexingMap.getNumDims();
74 SmallVector<AffineExpr> exprs(indexingMap.getResults());
75 ArrayRef<int64_t> innerDimsPos = packOrUnPackOp.getInnerDimsPos();
76 for (auto [index, innerDimPos, tileSize] :
77 llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
78 innerDimsPos, packOrUnPackOp.getMixedTiles())) {
79 auto expr = exprs[innerDimPos];
80 if (!isa<AffineDimExpr>(expr))
81 return failure();
82 int64_t domainDimPos =
83 cast<AffineDimExpr>(exprs[innerDimPos]).getPosition();
84 if (!isParallelIterator(iterators[domainDimPos]))
85 return failure();
86 packInfo.tiledDimsPos.push_back(domainDimPos);
87 packInfo.domainDimAndTileMapping[domainDimPos] = tileSize;
88 packInfo.tileToPointMapping[domainDimPos] = origNumDims + index;
89 LLVM_DEBUG({
90 llvm::dbgs() << "map innerDimPos=" << innerDimPos
91 << " to iteration dimension (d" << domainDimPos << ", d"
92 << packInfo.tileToPointMapping[domainDimPos]
93 << "), which has size=("
94 << packInfo.domainDimAndTileMapping[domainDimPos] << ")\n";
95 });
96 }
97
98 // Bail out if a tiled dimension is present in a map but not as an affine dim
99 // expression.
100 auto areAllAffineDimExpr = [&](int dim) {
101 for (AffineMap map : indexingMaps) {
102 if (llvm::any_of(map.getResults(), [dim](AffineExpr expr) {
103 return expr.isFunctionOfDim(dim) && !isa<AffineDimExpr>(expr);
104 })) {
105 return false;
106 }
107 }
108 return true;
109 };
110 for (int64_t i : packInfo.tiledDimsPos)
111 if (!areAllAffineDimExpr(i))
112 return failure();
113
114 // Get the outer dims perm on the iteration domain. Start by identifying the
115 // set of domain dims affected by the outer permutation along with the
116 // permuted ordering for those dims. Then the full outer dims permutation can
117 // be constructed by replacing the affected dims with the permuted result in a
118 // numLoops-rank identity. e.g.
119 // outerDimsPerm = [1, 2, 0]
120 // indexingMap = (d0, d1, d2, d3, d4) -> (d1, d4, d3)
121 //
122 // permutedOuterDims = [4, 3, 1]
123 // outerDimsOnDomainPerm = [0, 4, 2, 3, 1]
124 //
125 // Non-affine dim expressions must not be permuted by the outer dims
126 // permutation.
127 SmallVector<int64_t> permutedOuterDims;
128 for (auto [index, dim] : llvm::enumerate(packOrUnPackOp.getOuterDimsPerm())) {
129 auto permutedExpr = indexingMap.getResult(dim);
130 if (auto dimExpr = dyn_cast<AffineDimExpr>(permutedExpr)) {
131 permutedOuterDims.push_back(dimExpr.getPosition());
132 continue;
133 }
134
135 // TODO: Allow propagation with transposes on non affine dim expressions,
136 // e.g. d0 + d1 which implies transposing both dims simultaneously while
137 // maintaining the relative position between them.
138 if (static_cast<int64_t>(index) != dim)
139 return failure();
140 }
141 if (!permutedOuterDims.empty()) {
142 int64_t outerDimIndex = 0;
143 llvm::DenseSet<int64_t> permutedDomainDims(permutedOuterDims.begin(),
144 permutedOuterDims.end());
145 for (int i = 0, e = indexingMap.getNumDims(); i < e; i++)
146 packInfo.outerDimsOnDomainPerm.push_back(
147 permutedDomainDims.contains(i) ? permutedOuterDims[outerDimIndex++]
148 : i);
149 LLVM_DEBUG({
150 llvm::dbgs() << "map outerDimsPerm to ";
151 for (auto dim : packInfo.outerDimsOnDomainPerm)
152 llvm::dbgs() << dim << " ";
153 llvm::dbgs() << "\n";
154 });
155 }
156
157 return packInfo;
158}
159
160static SmallVector<int64_t> computeOuterDims(ArrayRef<int64_t> perm,
161 ArrayRef<AffineExpr> exprs) {
162 // Compute `outer_dims_perm`. See example:
163 // current exprs : (d0, d1, d2, d3) -> (d2, d3)
164 // perm : [0, 3, 1, 2]
165 // First map d2, d3 with their position in the array as:
166 // currentPositionTileLoops: dim | pos
167 // d2 | 0
168 // d3 | 1
169 // then scan `perm` in order and get the `outer_dims_perm`
170 // to be used, here it would be [1, 0].
171 assert(!perm.empty() && "expect perm not to be empty");
172 assert(!exprs.empty() && "expect exprs not to be empty");
173 if (exprs.size() == 1)
174 return {};
175 SmallVector<int64_t> outerDimsPerm;
176 DenseMap<int64_t, int64_t> currentPositionTileLoops;
177 for (auto [pos, expr] : llvm::enumerate(exprs)) {
178 // Here we rely on the assumption that the outer dims permutation
179 // when propagating currently requires that non-affine dim expressions
180 // are not permuted, thus allowing the identity assignment below.
181 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr))
182 currentPositionTileLoops[dimExpr.getPosition()] = pos;
183 else
184 currentPositionTileLoops[pos] = pos;
185 }
186 for (int64_t loopIdx : perm) {
187 if (currentPositionTileLoops.count(loopIdx))
188 outerDimsPerm.push_back(currentPositionTileLoops.lookup(loopIdx));
189 }
190 return outerDimsPerm;
191}
192
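/// Metadata describing how a single operand will be packed: the tile sizes,
/// tiled dim positions, and outer-dims permutation to use for the new
/// linalg.pack, plus the operand's indexing map in the packed iteration domain.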
193struct PackedOperandDetails {
194 SmallVector<OpFoldResult> innerTileSizes;
195 SmallVector<int64_t> innerDimsPos;
196 SmallVector<int64_t> outerDimsPerm;
197 AffineMap indexingMap;
198};
199
200/// Helper function for getOrCreatePackedViewOfOperand that populates
201/// the details of the packedOperand that needs to be formed and also
202/// returns whether the packing would require padding.
203static bool getPackedOperandDetails(
204 OpBuilder &b, PackInfo packInfo, GenericOp genericOp, OpOperand *opOperand,
205     DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
206 PackedOperandDetails currOperandDetails;
207 int64_t numOrigLoops = genericOp.getNumLoops();
208 int64_t numInnerLoops = packInfo.getNumTiledLoops();
209 int64_t numLoops = numOrigLoops + numInnerLoops;
210 AffineMap origIndexingMap = genericOp.getMatchingIndexingMap(opOperand);
211 llvm::DenseMap<int64_t, int64_t> domainDimToOperandDim;
212 SmallVector<AffineExpr> exprs(origIndexingMap.getResults());
213
214 // If the OpOperand is a scalar or a zero-rank tensor, no need to pack.
215 if (genericOp.isScalar(opOperand) || exprs.empty()) {
216 currOperandDetails.indexingMap =
217 AffineMap::get(numLoops, 0, exprs, b.getContext());
218 packedOperandMap[opOperand] = currOperandDetails;
219 return false;
220 }
221
222 // Step 1. Construct the information of packing data dimensions; append inner
223 // dimensions to the indexing maps for the operand.
224 for (auto [index, expr] : llvm::enumerate(exprs)) {
225 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
226 int64_t dimPos = dimExpr.getPosition();
227 domainDimToOperandDim[dimPos] = index;
228 continue;
229 }
230 }
231 SmallVector<int64_t> innerDimsPos;
232 SmallVector<OpFoldResult> innerTileSizes;
233 for (auto dimPos : packInfo.tiledDimsPos) {
234 if (!domainDimToOperandDim.count(dimPos))
235 continue;
236 int64_t index = domainDimToOperandDim[dimPos];
237 innerTileSizes.push_back(packInfo.domainDimAndTileMapping[dimPos]);
238 innerDimsPos.push_back(index);
239 exprs.push_back(b.getAffineDimExpr(packInfo.tileToPointMapping[dimPos]));
240 }
241
242 // Step 2. Handle outer dim permutations.
243 SmallVector<int64_t> outerDimsPerm;
244 if (!packInfo.outerDimsOnDomainPerm.empty()) {
245 outerDimsPerm = computeOuterDims(packInfo.outerDimsOnDomainPerm, exprs);
246
247 // Step 2.1: Fold transpose into the linalg.generic.
248 SmallVector<int64_t> inversedOuterPerm =
249 invertPermutationVector(packInfo.outerDimsOnDomainPerm);
250 for (auto i : llvm::seq<unsigned>(0, origIndexingMap.getNumResults())) {
251 if (auto dimExpr = dyn_cast<AffineDimExpr>(exprs[i])) {
252 int64_t dimPos = dimExpr.getPosition();
253 exprs[i] = b.getAffineDimExpr(inversedOuterPerm[dimPos]);
254 continue;
255 }
256 assert(isa<AffineConstantExpr>(exprs[i]) &&
257 "Attempted to permute non-constant and non-affine dim expression");
258 }
259 // Step 2.2: Undo the transposition on `exprs` and propagate the
260 // transposition on the pack using outerDimsPerm.
261 if (!outerDimsPerm.empty()) {
262 SmallVector<AffineExpr> auxVec = exprs;
263 for (const auto &en : enumerate(outerDimsPerm))
264 auxVec[en.index()] = exprs[en.value()];
265 exprs = auxVec;
266 }
267 }
268 currOperandDetails.indexingMap =
269 AffineMap::get(numLoops, 0, exprs, b.getContext());
270
271 // The operand does not have dimensions that relate to the pack op.
272 if (innerDimsPos.empty() && outerDimsPerm.empty()) {
273 packedOperandMap[opOperand] = currOperandDetails;
274 return false;
275 }
276 auto inputType = cast<RankedTensorType>(opOperand->get().getType());
277
278 auto maybeIntInnerTileSizes =
279 llvm::map_to_vector(innerTileSizes, [](OpFoldResult ofr) -> int64_t {
280 std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
281 return maybeCst.value_or(ShapedType::kDynamic);
282 });
283 bool requirePadding = linalg::PackOp::requirePaddingValueStrict(
284 inputType.getShape(), innerDimsPos,
285 linalg::PackOp::inferPackedType(inputType, maybeIntInnerTileSizes,
286 innerDimsPos, outerDimsPerm)
287 .getShape(),
288 outerDimsPerm, innerTileSizes);
289 currOperandDetails.innerDimsPos = innerDimsPos;
290 currOperandDetails.innerTileSizes = innerTileSizes;
291 currOperandDetails.outerDimsPerm = outerDimsPerm;
292 packedOperandMap[opOperand] = currOperandDetails;
293
294 return requirePadding;
295}
296
297/// Returns a tuple for packed operand and indexing_map with the assumptions:
298/// 1) The generic op is the producer of the pack op.
299/// 2) The generic op has only one result.
300/// If the operand is a scalar or packing dimensions are all irrelevant to the
301/// operand, the operand and the updated indexing map will be returned.
302/// Otherwise, it returns the packed operand and the updated indexing map. E.g.,
303///
304/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
305/// #map1 = affine_map<(d0, d1) -> (d0)>
306/// #map2 = affine_map<(d0, d1) -> (d1)>
307/// %0 = linalg.generic {indexing_maps = [#map1, #map2, #map0],
308/// iterator_types = ["parallel", "parallel"]}
309/// ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
310/// outs(%init : tensor<?x?xf32>) {
311/// ^bb0(%arg3: f32, %arg4: f32, %arg5: f32):
312/// %4 = arith.addf %arg3, %arg4 : f32
313/// linalg.yield %4 : f32
314/// } -> tensor<?x?xf32>
315/// %1 = linalg.pack %0
316/// inner_dims_pos = [0, 1]
317/// inner_tiles = [8, 2]
318/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
319///
320/// Taking the first input operand as an example, the inner tile size of d0 is
321/// 8. Thus, the below operation and the updated indexing map
322/// `affine_map<(d0, d1, d2, d3) -> (d0, d2)>` will be returned.
323///
324/// %pack = linalg.pack %arg0
325/// inner_dims_pos = [0]
326/// inner_tiles = [8]
327/// into %init : tensor<?xf32> -> tensor<?x8xf32>
328static std::tuple<Value, AffineMap> getOrCreatePackedViewOfOperand(
329 OpBuilder &b, Location loc, OpOperand *opOperand,
330 const DenseMap<OpOperand *, PackedOperandDetails> &packedOperandMap) {
331 assert(packedOperandMap.contains(opOperand) &&
332 "packed operand details expected to be populated");
333 auto currOperandDetails = packedOperandMap.at(opOperand);
334 auto innerDimsPos = currOperandDetails.innerDimsPos;
335 auto outerDimsPerm = currOperandDetails.outerDimsPerm;
336 auto innerTileSizes = currOperandDetails.innerTileSizes;
337 if (innerDimsPos.empty() && outerDimsPerm.empty())
338 return std::make_tuple(opOperand->get(), currOperandDetails.indexingMap);
339
340 auto empty = linalg::PackOp::createDestinationTensor(
341 b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
342 auto poison = ub::PoisonOp::create(
343 b, loc, getElementTypeOrSelf(opOperand->get().getType()));
344 Value packedOperand =
345 linalg::PackOp::create(b, loc, opOperand->get(), empty, innerDimsPos,
346 innerTileSizes, poison, outerDimsPerm);
347 return std::make_tuple(packedOperand, currOperandDetails.indexingMap);
348}
349
350/// This function is a helper subroutine to pack a genericOp and return it. It
351/// will create a new generic op with the packed operands and the packed output
352/// according to packInfo when we attempt to push down an unpack or bubble up a
353/// pack around it. Implicitly this will only work when a packInfo can be
354/// obtained. This makes sure that we are only using this function on parallel
355/// permuted dimensions.
356static FailureOr<GenericOp>
357packGenericOp(RewriterBase &rewriter, GenericOp genericOp, Value dest,
358 AffineMap packedOutIndexingMap, const PackInfo &packInfo,
359 bool isFoldableUnpackPack, bool poisonPaddingOk) {
360 Location loc = genericOp.getLoc();
361 SmallVector<Value> inputOperands;
362 SmallVector<Value> inputOperandsFromUnpackedSource;
363 SmallVector<AffineMap> indexingMaps;
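  // An unpack followed by a pack with the same outer_dims_perm, inner_dims_pos,
  // and tile sizes is a layout round-trip; in that case the unpack's source can
  // be used directly.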
364 auto hasEquivalentTiles = [](PackOp packOp, UnPackOp unPackOp) {
365 return packOp.getOuterDimsPerm() == unPackOp.getOuterDimsPerm() &&
366 packOp.getInnerDimsPos() == unPackOp.getInnerDimsPos() &&
367 llvm::equal(packOp.getMixedTiles(), unPackOp.getMixedTiles());
368 };
369 DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
370 bool requiresPadding = false;
371 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
372 requiresPadding |= getPackedOperandDetails(rewriter, packInfo, genericOp,
373 inputOperand, packedOperandMap);
374 }
375 if (requiresPadding && !poisonPaddingOk)
376 return failure();
377
378 for (OpOperand *inputOperand : genericOp.getDpsInputOperands()) {
379 auto [packedOperand, packedIndexingMap] = getOrCreatePackedViewOfOperand(
380 rewriter, loc, inputOperand, packedOperandMap);
381 auto unpackOp = inputOperand->get().getDefiningOp<linalg::UnPackOp>();
382 auto packOp = packedOperand.getDefiningOp<linalg::PackOp>();
383 if (packOp && unpackOp && hasEquivalentTiles(packOp, unpackOp)) {
384 inputOperandsFromUnpackedSource.push_back(unpackOp.getSource());
385 } else {
386 inputOperandsFromUnpackedSource.push_back(packedOperand);
387 }
388 inputOperands.push_back(packedOperand);
389 indexingMaps.push_back(packedIndexingMap);
390 }
391
392 // If the unpack->pack sequences can be folded, use the sources of the unpack
393 // ops instead for any unpack->pack chains on the generic op operands.
394 if (isFoldableUnpackPack) {
395 inputOperands = inputOperandsFromUnpackedSource;
396 if (auto destPack = dest.getDefiningOp<linalg::PackOp>()) {
397 auto destUnPack = destPack.getSource().getDefiningOp<linalg::UnPackOp>();
398 if (destUnPack && hasEquivalentTiles(destPack, destUnPack)) {
399 dest = destUnPack.getSource();
400 }
401 }
402 }
403
404 int64_t numInnerLoops = packInfo.getNumTiledLoops();
405 SmallVector<utils::IteratorType> iterTypes =
406     genericOp.getIteratorTypesArray();
407 iterTypes.append(numInnerLoops, utils::IteratorType::parallel);
408
409 indexingMaps.push_back(packedOutIndexingMap);
410
411 auto newGenericOp = linalg::GenericOp::create(
412 rewriter, loc, dest.getType(), inputOperands, dest, indexingMaps,
413 iterTypes,
414 /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
415 rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
416 newGenericOp.getRegion().begin());
417 return newGenericOp;
418}
419
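/// Returns true if none of the init (outs) block arguments are used in the
/// body of the generic op, i.e. all outputs are write-only.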
420static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
421 return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
422 return genericOp.getMatchingBlockArgument(&operand).use_empty();
423 });
424}
425
426/// Bubbles up a linalg.pack op through a producer generic op. This
427/// swaps pack(generic) to generic(pack). The new generic op works on the packed
428/// domain; pack ops are created for input and output operands. E.g.,
429///
430/// #map0 = affine_map<(d0, d1) -> (d0, d1)>
431/// %0 = tensor.dim %arg0, %c0 : tensor<?x?xf32>
432/// %1 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
433/// %2 = tensor.empty(%0, %1) : tensor<?x?xf32>
434/// %3 = linalg.generic {indexing_maps = [#map0, #map0],
435/// iterator_types = ["parallel", "parallel"]}
436/// ins(%arg0 : tensor<?x?xf32>)
437/// outs(%2 : tensor<?x?xf32>) {
438/// ^bb0(%arg3: f32, %arg4: f32):
439/// %4 = arith.addf %arg3, %arg3 : f32
440/// linalg.yield %4 : f32
441/// } -> tensor<?x?xf32>
442/// %4 = linalg.pack %3
443/// inner_dims_pos = [0, 1]
444/// inner_tiles = [8, 2]
445/// into %dest : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
446///
447/// will be converted to
448///
449/// #map = affine_map<()[s0] -> (s0 ceildiv 8)>
450/// #map1 = affine_map<()[s0] -> (s0 ceildiv 2)>
451/// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
452/// %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
453/// %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
454/// %0 = affine.apply #map()[%dim]
455/// %1 = affine.apply #map1()[%dim_0]
456/// %2 = tensor.empty(%0, %1) : tensor<?x?x8x2xf32>
457/// %pack = linalg.pack %arg0
458/// inner_dims_pos = [0, 1]
459/// inner_tiles = [8, 2]
460/// into %2 : tensor<?x?xf32> -> tensor<?x?x8x2xf32>
461/// %3 = linalg.generic {indexing_maps = [#map2, #map2],
462/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
463/// ins(%pack : tensor<?x?x8x2xf32>)
464/// outs(%arg1 : tensor<?x?x8x2xf32>) {
465/// ^bb0(%in: f32, %out: f32):
466/// %4 = arith.addf %in, %in : f32
467/// linalg.yield %4 : f32
468/// } -> tensor<?x?x8x2xf32>
469static FailureOr<GenericOp>
470bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
471 const ControlPropagationFn &controlFn,
472 bool poisonPaddingOk) {
473 auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
474 if (!genericOp)
475 return failure();
476
477 // User controlled propagation function.
478 if (!controlFn(&packOp.getSourceMutable()))
479 return failure();
480
481 // TODO: Enable propagation in the presence of linalg.index and
482 // tensor.extract, likely as a separate pattern as the pack information and
483 // propagation decision needs to be inferred from the region of the generic.
484 if (hasGatherSemantics(genericOp))
485 return failure();
486
487 // TODO: Relax the restriction. We are able to bubble up the pack op through
488 // multi-result generic op. It just needs more work.
489 if (genericOp.getNumResults() != 1)
490 return failure();
491
492 // Bail-out if the result of the generic has multiple uses, as bubbling up
493 // creates recomputation if the generic has multiple users.
494 // TODO: Enable the case where every use is an identical pack op as no
495 // recomputation is needed in that case.
496 if (!genericOp->getResult(0).hasOneUse())
497 return failure();
498
499 // TODO: Add an option for allowing padding values. It could introduce
500 // undefined behavior if we unconditionally propagate pack op through all
501 // the ops. E.g., if the padding value is zero and there are division ops in
502 // a generic op. Some values of padding area could be NaN (0/0).
503 if (packOp.getPaddingValue())
504 return failure();
505
506 OpOperand *opOperand = genericOp.getDpsInitOperand(0);
507 auto packInfo = getPackingInfoFromOperand(opOperand, genericOp, packOp);
508 if (failed(packInfo))
509 return failure();
510
511 // We want to move the pack not the generic.
512 OpBuilder::InsertionGuard guard(rewriter);
513 rewriter.setInsertionPoint(genericOp);
514
515 // We need to handle two cases:
516 // 1) The linalg.pack destination is a tensor.empty. If this is the case, we
517 // create a new tensor.empty to avoid breaking dominance, as we are moving the
518 // linalg.pack above the linalg.generic.
519 // 2) The destination is not a tensor.empty. In this case we can replace only
520 // if the destination of the linalg.pack dominates the linalg.generic.
521 Value packOpDest = packOp.getDest();
522 if (!packOpDest.hasOneUse())
523 return failure();
524 if (auto emptyOp = packOpDest.getDefiningOp<tensor::EmptyOp>()) {
525 packOpDest = tensor::EmptyOp::create(rewriter, genericOp->getLoc(),
526 emptyOp.getMixedSizes(),
527 emptyOp.getType().getElementType());
528 } else {
529 DominanceInfo dom(genericOp);
530 if (!dom.properlyDominates(packOpDest, genericOp))
531 return failure();
532 }
533
534 // Rebuild the indexing map for the corresponding init operand.
535 DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
536 bool requiresPadding = getPackedOperandDetails(rewriter, *packInfo, genericOp,
537 opOperand, packedOperandMap);
538 if (requiresPadding && !poisonPaddingOk)
539 return failure();
540
541 auto [packedOutOperand, packedOutIndexingMap] =
542 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), opOperand,
543 packedOperandMap);
544 // Forward the new tensor.empty as the destination in one of the following
545 // situations:
546 // 1) The dps init operand is a tensor.empty.
547 // 2) The dps init is a write-only operand, i.e., it is not used in the
548 // genericOp.
549 Value dest = packedOutOperand;
550 auto initTensor =
551 genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
552 if (initTensor || isGenericOutsNotUsed(genericOp)) {
553 dest = packOpDest;
554 }
555 // pack(unpack) isn't naively foldable because the unpack op can be from
556 // an arbitrary domain so we need to keep both.
557 return packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap,
558 *packInfo, /*isFoldableUnpackPack=*/false,
559 poisonPaddingOk);
560}
561
562/// Wrapper pattern that applies bubbleUpPackOpThroughGenericOp method.
563struct BubbleUpPackOpThroughGenericOpPattern
564 : public OpRewritePattern<linalg::PackOp> {
565public:
566 BubbleUpPackOpThroughGenericOpPattern(MLIRContext *context,
567     ControlPropagationFn fun,
568     bool poisonPaddingOk)
569 : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)),
570 poisonPaddingOk(std::move(poisonPaddingOk)) {}
571
572 LogicalResult matchAndRewrite(linalg::PackOp packOp,
573 PatternRewriter &rewriter) const override {
574 auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn,
575 poisonPaddingOk);
576 if (failed(genericOp))
577 return failure();
578 rewriter.replaceOp(packOp, genericOp->getResults());
579 return success();
580 }
581
582private:
583 ControlPropagationFn controlFn;
584 bool poisonPaddingOk;
585};
586
587/// Propagate a linalg.pack operation up through a tensor.pad. The idea is to
588/// append as many zero-padding entries to `high` and `low` as the number of
589/// point loops.
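///
/// For example (an illustrative sketch; shapes, tile sizes, and value names
/// are assumed, and the pad body is elided):
///
///   %padded = tensor.pad %src low[0, 0] high[0, 2]
///       : tensor<16x13xf32> to tensor<16x15xf32>
///   %pack = linalg.pack %padded
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<16x15xf32> -> tensor<2x15x8xf32>
///
/// becomes
///
///   %pack = linalg.pack %src
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<16x13xf32> -> tensor<2x13x8xf32>
///   %padded = tensor.pad %pack low[0, 0, 0] high[0, 2, 0]
///       : tensor<2x13x8xf32> to tensor<2x15x8xf32>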
590class BubbleUpPackThroughPadOp final : public OpRewritePattern<linalg::PackOp> {
591public:
592 BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
593 : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
594
595 LogicalResult matchAndRewrite(linalg::PackOp packOp,
596 PatternRewriter &rewriter) const override {
597 auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
598 if (!padOp)
599 return failure();
600
601 // User controlled propagation function.
602 if (!controlFn(&packOp.getSourceMutable()))
603 return failure();
604
605 // TODO: Enable padding when the padding values are the same.
606 if (packOp.getPaddingValue())
607 return failure();
608
609 // Fail for non-constant padding values. The body of the pad could
610 // depend on the padding indices and/or properties of the padded
611 // tensor so for now we fail.
612 // TODO: Support non-constant padding values.
613 Value paddingVal = padOp.getConstantPaddingValue();
614 if (!paddingVal)
615 return failure();
616
617 if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
618 return failure();
619
620 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
621
622 // Bail out if one of the padded dimensions is a tiled one.
623 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
624 llvm::SmallBitVector innerDims(paddedDims.size());
625 for (int64_t dim : innerDimsPos)
626 innerDims.flip(dim);
627 if (paddedDims.anyCommon(innerDims))
628 return failure();
629
630 Location loc = padOp->getLoc();
631 OpBuilder::InsertionGuard guard(rewriter);
632 rewriter.setInsertionPoint(padOp);
633
634 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
635 SmallVector<OpFoldResult> mixedTiles = packOp.getMixedTiles();
636 auto empty = linalg::PackOp::createDestinationTensor(
637 rewriter, loc, padOp.getSource(), mixedTiles, innerDimsPos,
638 outerDimsPerm);
639 auto sourcePack = linalg::PackOp::create(
640 rewriter, loc, padOp.getSource(), empty, innerDimsPos, mixedTiles,
641 /*padding=*/std::nullopt, outerDimsPerm);
642
643 // If we have `outer_dims_perm` we need to adjust the padded dimensions.
644 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
645 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
646 if (!outerDimsPerm.empty()) {
647 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
648 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
649 }
650 // The tiled dimensions were verified to be unpadded above, so here we
651 // just append 0 for the inner tile dimensions.
652 size_t pointLoopsSize = innerDimsPos.size();
653 lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
654 highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
655
656 auto newPadOp =
657 tensor::PadOp::create(rewriter, loc, /*result=*/Type(), sourcePack,
658 lowPad, highPad, paddingVal, padOp.getNofold());
659
660 // If the pad has more than one user, create an unpack on the new pad to
661 // replace the other uses.
662 if (!padOp->hasOneUse()) {
663 auto unpackEmpty = linalg::UnPackOp::createDestinationTensor(
664 rewriter, loc, newPadOp, mixedTiles, innerDimsPos, outerDimsPerm);
665 Value unpackedPad =
666 linalg::UnPackOp::create(rewriter, loc, newPadOp, unpackEmpty,
667 innerDimsPos, mixedTiles, outerDimsPerm);
668 rewriter.replaceAllUsesExcept(padOp, unpackedPad, sourcePack);
669 }
670
671 // Replace the pack with the new pad.
672 rewriter.replaceOp(packOp, newPadOp.getResult());
673
674 return success();
675 }
676
677private:
678 ControlPropagationFn controlFn;
679};
680
681/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
682///
683/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
684/// targetShape [16, 16, 32, 1], it returns [1, 2]. Because for pos 0, the
685/// inner-most projected dim in pos [0, 1] is 1. And for pos 2, the inner-most
686/// non-unit projected dim in pos [2, 3] is 2.
687///
688/// If all candidates in a reassociation are unit dims, it chooses the
689/// inner-most dim pos.
690static SmallVector<int64_t>
691projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
692 ArrayRef<ReassociationIndices> reassocIndices,
693 ArrayRef<int64_t> targetShape) {
694 SmallVector<int64_t> projectedDimsPos;
695 for (auto pos : dimsPos) {
696 // In the case all dims are unit, this will return the inner-most one.
697 int64_t projectedPos = reassocIndices[pos].back();
698 for (auto i : llvm::reverse(reassocIndices[pos])) {
699 int64_t dim = targetShape[i];
700 if (dim > 1 || ShapedType::isDynamic(dim)) {
701 projectedPos = i;
702 break;
703 }
704 }
705 projectedDimsPos.push_back(projectedPos);
706 }
707 return projectedDimsPos;
708}
709
710/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
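/// For example, with shape [7, 16, 4], dimsPos [1, 2], and tileSizes [8, 1],
/// this returns true (16 % 8 == 0 and 4 % 1 == 0); a dynamic dim at any
/// queried position makes it return false.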
711static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
712     ArrayRef<int64_t> shape,
713     ArrayRef<int64_t> tileSizes) {
714 for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
715 int64_t dim = shape[pos];
716 if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
717 return false;
718 }
719 return true;
720}
721
722/// Permute the reassociation indices and reindex them in sequence order.
723/// Returns the next dim pos in the sequence.
724///
725/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
726/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
727/// [[0], [1, 2]].
728static int64_t applyPermutationAndReindexReassoc(
729 SmallVector<ReassociationIndices> &reassocIndices,
730 ArrayRef<int64_t> permutation) {
731 if (!permutation.empty())
732 applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
733 int64_t nextPos = 0;
734 for (ReassociationIndices &indices : reassocIndices) {
735 for (auto &index : indices) {
736 index = nextPos;
737 nextPos += 1;
738 }
739 }
740 return nextPos;
741}
742
743/// Bubble up pack op through collapse shape op when the packed dims can be
744/// projected to the dims before collapsing. This is possible when the inner
745/// tile sizes can divide the projected dims.
746///
747/// For example:
748///
749/// %collapsed = tensor.collapse_shape %in [[0, 1], 2]
750/// : tensor<?x16x4xf32> into tensor<?x4xf32>
751/// %pack = linalg.pack %collapsed outer_dims_perm = [0, 1]
752/// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
753/// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
754///
755/// can be transformed into:
756///
757/// %pack = linalg.pack %in outer_dims_perm = [1, 2]
758/// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
759/// : tensor<?x16x4xf32> -> tensor<?x2x4x8x1xf32>
760/// %collapsed = tensor.collapse_shape %pack [[0, 1], 2, 3, 4]
761/// : tensor<?x2x4x8x1xf32> into tensor<?x4x8x1xf32>
762static LogicalResult
763bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
764 linalg::PackOp packOp,
765 PatternRewriter &rewriter) {
766 SmallVector<int64_t> innerTileSizes = packOp.getStaticTiles();
767 ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
768 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
769
770 ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
771 SmallVector<ReassociationIndices> reassocIndices =
772 collapseOp.getReassociationIndices();
773 // Project inner tile pos to the dim pos before collapsing. For example, if
774 // dims [x, y] are collapsed into [z], packing on dim z can be projected back
775 // to pack on dim y.
776 //
777 // Project to inner-most non-unit dims to increase the chance that they can be
778 // divided by the inner tile sizes. This is correct because for [..., x, 1],
779 // packing on dim 1 is equivalent to packing on dim x.
780 SmallVector<int64_t> projectedInnerDimsPos =
781 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
782
783 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
784 innerTileSizes)) {
785 return failure();
786 }
787 // Expand the outer dims permutation with the associated source dims for the
788 // new permutation after bubbling. This is because moving a collapsed dim is
789 // equivalent to moving the associated source dims together.
790 SmallVector<int64_t> newOuterDimsPerm;
791 for (auto outerPos : outerDimsPerm)
792 llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
793
794 auto emptyOp = linalg::PackOp::createDestinationTensor(
795 rewriter, packOp.getLoc(), collapseOp.getSrc(), packOp.getMixedTiles(),
796 projectedInnerDimsPos, newOuterDimsPerm);
797 auto newPackOp = linalg::PackOp::create(
798 rewriter, packOp.getLoc(), collapseOp.getSrc(), emptyOp,
799 projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
800 newOuterDimsPerm);
801
802 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
803 // First apply the permutation on the reassociations of the outer dims.
804 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
805 // -> [[0], [1, 2]]
806 int64_t nextPos =
807 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
808 // Then add direct mapping for the inner tile dims.
809 for (size_t i = 0; i < innerDimsPos.size(); ++i) {
810 newReassocIndices.push_back({nextPos});
811 nextPos += 1;
812 }
813
814 auto newCollapseOp = tensor::CollapseShapeOp::create(
815 rewriter, collapseOp.getLoc(), packOp.getType(), newPackOp,
816 newReassocIndices);
817 rewriter.replaceOp(packOp, newCollapseOp);
818
819 return success();
820}
821
822/// Project dimsPos to their collapsed positions in the reassocIndices.
823///
824/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
825/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
826/// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
827/// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
828static SmallVector<int64_t>
829projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
830 ArrayRef<ReassociationIndices> reassocIndices) {
831 SmallVector<int64_t> projectedPos;
832
833 // Map each dimension to the position of corresponding reassociation index.
834 for (auto pos : dimsPos) {
835 for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
836 // If the dimension is present in the current indices group, the group
837 // position within the reassociation map is the desired projected
838 // dimension position.
839 if (llvm::is_contained(indices, pos)) {
840 projectedPos.push_back(idx);
841 break;
842 }
843 }
844 }
845 assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
846
847 return projectedPos;
848}
849
850/// Bubble up pack op through expand shape op.
851///
852/// For example:
853///
854/// %expand = tensor.expand_shape %in [[0], [1, 2]]
855/// : tensor<?x64xf32> into tensor<?x4x16xf32>
856/// %pack = linalg.pack %expand outer_dims_perm = [0, 1]
857/// inner_dims_pos = [2] inner_tiles = [8] into %empty
858/// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
859///
860/// can be transformed into:
861///
862/// %pack = linalg.pack %in outer_dims_perm = [1, 2]
863/// inner_dims_pos = [1] inner_tiles = [8] into %empty
864/// : tensor<?x64xf32> -> tensor<?x8x8xf32>
865/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
866/// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
867static LogicalResult
868bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
869 linalg::PackOp packOp,
870 PatternRewriter &rewriter) {
871 // Outer dimensions permutation is not supported currently.
872 // TODO: Handle outer_dims_perm variants.
873 ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
874 if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
875 return rewriter.notifyMatchFailure(packOp,
876 "non-identity outer dims perm NYI");
877 }
878
879 // Validate dimensions' relations between shape expansion and packing.
880 SmallVector<ReassociationIndices> reassoc =
881     expandOp.getReassociationIndices();
882 ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
883 llvm::SetVector<int64_t> packDimsPos(llvm::from_range, packInnerDims);
884
885 for (auto [idx, indices] : llvm::enumerate(reassoc)) {
886 // For each expand_shape reassociation, figure out which dimensions get
887 // packed if any.
888 llvm::SetVector<int64_t> expandDimPos(llvm::from_range, indices);
889 llvm::SetVector<int64_t> packedDims =
890 llvm::set_intersection(packDimsPos, expandDimPos);
891
892 // The expanded dimension is not packed, so it does not affect moving the pack
893 // before the shape expansion - simply continue.
894 if (packedDims.empty())
895 continue;
896 // Shape expansion cannot be propagated when multiple expanded dimensions are
897 // packed - in this case operation reordering would affect final element
898 // positions and/or shapes can no longer be projected.
899 if (packedDims.size() != 1)
900 return rewriter.notifyMatchFailure(
901 packOp, "only one of the expanded dimensions can be packed");
902 // Only the inner-most expanded dimension should be packed. Otherwise,
903 // elements order will be affected after operation reordering.
904 if (packedDims.front() != indices.back())
905 return rewriter.notifyMatchFailure(
906 packOp, "can only pack the inner-most expanded dimension");
907 }
908
909 // Project pack.inner_dims_pos to positions before shape expansion.
910 SmallVector<int64_t> projectedInnerDimsPos =
911 projectDimsPosIntoReassocPos(packInnerDims, reassoc);
912
913 // Project the shape expansion to new packed shape.
914 // The pack.outer_dims_perm is restricted to identity, so the permutation can
915 // be omitted for simplicity.
916 // TODO: Account for outer dimensions permutation.
917 //
918 // If reassociation is not possible, then reordering cannot happen.
919 // This can be caused by pack padding affecting previously expanded
920 // dimensions or packing extending dimensions.
921 RankedTensorType newPackType = linalg::PackOp::inferPackedType(
922 expandOp.getSrcType(), packOp.getStaticInnerTiles(),
923 projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
924 auto reassocExpand =
925 getReassociationIndicesForReshape(newPackType, packOp.getDestType());
926 if (!reassocExpand)
927 return rewriter.notifyMatchFailure(
928 packOp, "could not reassociate dims after bubbling up");
929
930 Value destTensor = linalg::PackOp::createDestinationTensor(
931 rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
932 projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
933 Value packedVal = linalg::PackOp::create(
934 rewriter, packOp.getLoc(), expandOp.getSrc(), destTensor,
935 projectedInnerDimsPos, packOp.getMixedTiles(), packOp.getPaddingValue(),
936 /*outerDimsPerm=*/SmallVector<int64_t>{});
937
938 Value newExpandOp = tensor::ExpandShapeOp::create(rewriter, packOp.getLoc(),
939 packOp.getDestType(),
940 packedVal, *reassocExpand);
941 rewriter.replaceOp(packOp, newExpandOp);
942
943 return success();
944}
945
946class BubbleUpPackOpThroughReshapeOp final
947 : public OpRewritePattern<linalg::PackOp> {
948public:
949 BubbleUpPackOpThroughReshapeOp(MLIRContext *context, ControlPropagationFn fun)
950 : OpRewritePattern<linalg::PackOp>(context), controlFn(std::move(fun)) {}
951
952 LogicalResult matchAndRewrite(linalg::PackOp packOp,
953 PatternRewriter &rewriter) const override {
954 Operation *srcOp = packOp.getSource().getDefiningOp();
955 // Currently only support when the pack op is the only user.
956 if (!srcOp || !(srcOp->getNumResults() == 1) ||
957 !srcOp->getResult(0).hasOneUse()) {
958 return failure();
959 }
960 // Currently only support static inner tile sizes.
961 if (llvm::any_of(packOp.getStaticTiles(), ShapedType::isDynamic))
962 return failure();
963
964 // User controlled propagation function.
965 if (!controlFn(&packOp.getSourceMutable()))
966 return failure();
967
968 return TypeSwitch<Operation *, LogicalResult>(srcOp)
969     .Case([&](tensor::CollapseShapeOp op) {
970 return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
971 })
972 .Case([&](tensor::ExpandShapeOp op) {
973 return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
974 })
975 .Default([](Operation *) { return failure(); });
976 }
977
978private:
979 ControlPropagationFn controlFn;
980};
981
982/// Push down unpack op through expand shape op when the packed dims can be
983/// projected to the dims after expanding. This is possible when the inner tile
984/// sizes can divide the projected dims.
985///
986/// For example:
987///
988/// %unpack = linalg.unpack %in outer_dims_perm = [0, 1]
989/// inner_dims_pos = [0, 1] inner_tiles = [8, 8] into %empty
990/// : tensor<?x32x8x8xf32> -> tensor<?x256xf32>
991/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
992/// : tensor<?x256xf32> into tensor<?x256x256xf32>
993///
994/// can be transformed into:
995///
996/// %expanded = tensor.expand_shape %in [[0, 1], [2], [3], [4]]
997/// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
998/// %unpack = linalg.unpack %expanded outer_dims_perm = [0, 1, 2]
999/// inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %empty
1000/// : tensor<?x32x32x8x8xf32> -> tensor<?x256x256xf32>
1001static LogicalResult pushDownUnPackOpThroughExpandShape(
1002 linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp,
1003 PatternRewriter &rewriter, ControlPropagationFn controlFn) {
1004 // User controlled propagation function.
1005 if (!controlFn(&expandOp.getSrcMutable()))
1006 return failure();
1007
1008 SmallVector<int64_t> innerTileSizes = unPackOp.getStaticTiles();
1009 ArrayRef<int64_t> innerDimsPos = unPackOp.getInnerDimsPos();
1010 ArrayRef<int64_t> outerDimsPerm = unPackOp.getOuterDimsPerm();
1011
1012 auto expandTy = dyn_cast<RankedTensorType>(expandOp.getType());
1013 if (!expandTy)
1014 return failure();
1015 ArrayRef<int64_t> dstShape = expandTy.getShape();
1016 SmallVector<ReassociationIndices> reassocIndices =
1017 expandOp.getReassociationIndices();
1018 // Project inner tile pos to the dim pos after expanding. For example, if dim
1019 // [z] is expanded into [x, y], unpacking on dim z can be projected to unpack
1020 // on dim y.
1021 //
1022 // Project to inner-most non-unit dims to increase the chance that they can be
1023 // divided by the inner tile sizes. This is correct because for [..., x, 1],
1024 // unpacking on dim 1 is equivalent to unpacking on dim x.
1025 SmallVector<int64_t> projectedInnerDimsPos =
1026 projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
1027
1028 if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
1029 innerTileSizes)) {
1030 return failure();
1031 }
1032 // Expand the outer dims permutation with the associated expanded dims for the
1033 // new permutation after pushing. This is because moving a source dim is
1034 // equivalent to moving the associated expanded dims together.
1035 SmallVector<int64_t> newOuterDimsPerm;
1036 for (auto outerPos : outerDimsPerm)
1037 llvm::append_range(newOuterDimsPerm, reassocIndices[outerPos]);
1038
1039 SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
1040 // First apply the permutation on the reassociations of the outer dims.
1041 // For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
1042 // -> [[0], [1, 2]]
1043 int64_t nextPos =
1044 applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
1045 // Then add direct mapping for the inner tile dims.
1046 for (size_t i = 0; i < innerDimsPos.size(); ++i) {
1047 newReassocIndices.push_back({nextPos});
1048 nextPos += 1;
1049 }
1050
1051 RankedTensorType newExpandType = linalg::PackOp::inferPackedType(
1052 expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm);
1053 auto newExpandOp =
1054 tensor::ExpandShapeOp::create(rewriter, expandOp.getLoc(), newExpandType,
1055 unPackOp.getSource(), newReassocIndices);
1056
1057 auto emptyOp = linalg::UnPackOp::createDestinationTensor(
1058 rewriter, unPackOp.getLoc(), newExpandOp, unPackOp.getMixedTiles(),
1059 projectedInnerDimsPos, newOuterDimsPerm);
1060 auto newUnPackOp = linalg::UnPackOp::create(
1061 rewriter, unPackOp.getLoc(), newExpandOp.getResult(), emptyOp,
1062 projectedInnerDimsPos, unPackOp.getMixedTiles(), newOuterDimsPerm);
1063 rewriter.replaceOp(expandOp, newUnPackOp);
1064
1065 return success();
1066}
1067
1068class PushDownUnPackOpThroughReshapeOp final
1069 : public OpRewritePattern<linalg::UnPackOp> {
1070public:
1071 PushDownUnPackOpThroughReshapeOp(MLIRContext *context,
1072     ControlPropagationFn fun)
1073     : OpRewritePattern<linalg::UnPackOp>(context), controlFn(std::move(fun)) {
1074 }
1075
1076 LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp,
1077 PatternRewriter &rewriter) const override {
1078 Value result = unPackOp.getResult();
1079 // Currently only support unpack ops with a single user.
1080 if (!result.hasOneUse()) {
1081 return failure();
1082 }
1083 // Currently only support static inner tile sizes.
1084 if (llvm::any_of(unPackOp.getStaticTiles(), ShapedType::isDynamic))
1085 return failure();
1086
1087 Operation *consumerOp = *result.user_begin();
1088 return TypeSwitch<Operation *, LogicalResult>(consumerOp)
1089     .Case([&](tensor::ExpandShapeOp op) {
1090 return pushDownUnPackOpThroughExpandShape(unPackOp, op, rewriter,
1091 controlFn);
1092 })
1093 .Default([](Operation *) { return failure(); });
1094 }
1095
1096private:
1097 ControlPropagationFn controlFn;
1098};
1099
1100// TODO: Relax this restriction. We should unpack a generic op also
1101// in the presence of multiple unpack ops as producers.
1102/// Return the unpacked operand, if present, for the current generic op.
1103static FailureOr<OpOperand *> getUnPackedOperand(GenericOp genericOp) {
1104 OpOperand *unPackedOperand = nullptr;
1105 for (OpOperand &operand : genericOp->getOpOperands()) {
1106 auto unPackOp = operand.get().getDefiningOp<linalg::UnPackOp>();
1107 if (!unPackOp)
1108 continue;
1109 if (unPackedOperand)
1110 return failure();
1111 unPackedOperand = &operand;
1112 }
1113 if (!unPackedOperand)
1114 return failure();
1115 return unPackedOperand;
1116}
1117
1118/// Push down a linalg.unpack op through a generic op.
1119/// The new generic op works on packed domain; pack ops are created for input
1120/// and output operands. A linalg.unpack op is inserted right after the packed
1121/// generic. E.g.
1122///
1123/// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
1124///
1125/// %arg0 = tensor<12x2x56x56x32xf32> // packed arg.
1126///
1127/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1128/// %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2]
1129/// inner_dims_pos = [3] inner_tiles = [32] into %0
1130/// %2 = linalg.generic {indexing_maps = [#map],
1131/// iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
1132/// outs(%1 : tensor<12x56x56x64xf32>) {
1133/// ^bb0(%out : f32):
1134/// linalg.yield %out : f32
1135/// } -> tensor<12x56x56x64xf32>
1136///
1137/// will be converted to
1138///
1139/// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
1140///
1141/// %0 = tensor.empty() : tensor<12x56x56x64xf32>
1142/// %1 = linalg.generic {indexing_maps = [#map],
1143/// iterator_types = ["parallel", "parallel", "parallel",
1144/// "parallel", "parallel"]}
1145/// outs(%arg0 : tensor<12x2x56x56x32xf32>) {
1146/// ^bb0(%out : f32):
1147/// linalg.yield %out : f32
1148/// } -> tensor<12x2x56x56x32xf32>
1149/// %2 = linalg.unpack %1 outer_dims_perm = [0, 3, 1, 2]
1150/// inner_dims_pos = [3] inner_tiles = [32] into %0
1151///
1152static FailureOr<std::tuple<GenericOp, Value>>
1153pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
1154 ControlPropagationFn controlFn,
1155 bool poisonPaddingOk) {
1156 if (genericOp.getNumResults() != 1)
1157 return failure();
1158
1159 if (hasGatherSemantics(genericOp))
1160 return failure();
1161
1162 // Collect the unPacked operand, if present.
1163 auto maybeUnPackedOperand = getUnPackedOperand(genericOp);
1164 if (failed(maybeUnPackedOperand))
1165 return failure();
1166 OpOperand *unPackedOperand = *(maybeUnPackedOperand);
1167
1168 // Extract packing information.
1169 linalg::UnPackOp producerUnPackOp =
1170 unPackedOperand->get().getDefiningOp<linalg::UnPackOp>();
1171 assert(producerUnPackOp && "expect a valid UnPackOp");
1172
1173 if (!controlFn(unPackedOperand))
1174 return failure();
1175
1176 auto packInfo =
1177 getPackingInfoFromOperand(unPackedOperand, genericOp, producerUnPackOp);
1178 if (failed(packInfo))
1179 return failure();
1180
1181 // Rebuild the indexing map for the corresponding init operand.
1182 DenseMap<OpOperand *, PackedOperandDetails> packedOperandMap;
1183 bool requiresPadding =
1184 getPackedOperandDetails(rewriter, *packInfo, genericOp,
1185 genericOp.getDpsInitOperand(0), packedOperandMap);
1186 if (requiresPadding && !poisonPaddingOk)
1187 return failure();
1188
1189 auto [packedOutOperand, packedOutIndexingMap] =
1190 getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(),
1191 genericOp.getDpsInitOperand(0),
1192 packedOperandMap);
1193 auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
1194
1195 // Forward the new tensor.empty as the destination in one of the following
1196 // situations:
1197 // 1) The dps init operand is a tensor.empty.
1198 // 2) The dps init is a write-only operand, i.e., it is not used in the
1199 // genericOp.
1200 Value dest = packedOutOperand;
1201 auto initTensor =
1202 genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
1203 if (initTensor || isGenericOutsNotUsed(genericOp)) {
1204 if (destPack)
1205 dest = destPack.getDest();
1206 }
1207
1208 // Pack the genericOp.
1209 // pack(unpack) is foldable in this case. This is because in pushing down the
1210 // unpack, by default we will populate an additional pack op after the unpack.
1211 // This guarantees them to be foldable.
1212 auto maybeGenericOp =
1213 packGenericOp(rewriter, genericOp, dest, packedOutIndexingMap, *packInfo,
1214 /*isFoldableUnpackPack=*/true, poisonPaddingOk);
1215 if (failed(maybeGenericOp))
1216 return failure();
1217 GenericOp newGenericOp = *maybeGenericOp;
1218 Value newResult =
1219 newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0));
1220
1221 // If the output is unaffected, no need to unpack.
1222 if (!destPack)
1223 return std::make_tuple(newGenericOp, newResult);
1224
1225 auto mixedTiles = destPack.getMixedTiles();
1226 auto innerDimsPos = destPack.getInnerDimsPos();
1227 auto outerDimsPerm = destPack.getOuterDimsPerm();
1228
1229 // Insert an unPackOp right after the packed generic.
1230 Value unPackOpRes =
1231 linalg::UnPackOp::create(rewriter, genericOp.getLoc(), newResult,
1232 destPack.getSource(), innerDimsPos, mixedTiles,
1233 outerDimsPerm)
1234 .getResult();
1235
1236 return std::make_tuple(newGenericOp, unPackOpRes);
1237}
1238
1239// Wrapper pattern that applies pushDownUnPackOpThroughGenericOp method.
1240struct PushDownUnPackOpThroughGenericOp : public OpRewritePattern<GenericOp> {
1241public:
1242 PushDownUnPackOpThroughGenericOp(MLIRContext *context,
1243     ControlPropagationFn fun,
1244     bool poisonPaddingOk)
1245 : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)),
1246 poisonPaddingOk(std::move(poisonPaddingOk)) {}
1247
1248 LogicalResult matchAndRewrite(GenericOp genericOp,
1249 PatternRewriter &rewriter) const override {
1250 auto genericAndRepl = pushDownUnPackOpThroughGenericOp(
1251 rewriter, genericOp, controlFn, poisonPaddingOk);
1252 if (failed(genericAndRepl))
1253 return failure();
1254 rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1255 return success();
1256 }
1257
1258private:
1259 ControlPropagationFn controlFn;
1260 bool poisonPaddingOk;
1261};
1262
1263/// Propagate a linalg.unpack operation through a tensor.pad. The idea is to
1264/// append as many zero-padding entries to `high` and `low` as the number of
1265/// point loops.
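///
/// For example (an illustrative sketch; shapes, tile sizes, and value names
/// are assumed, and the pad body is elided):
///
///   %unpack = linalg.unpack %src
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %dest : tensor<2x13x8xf32> -> tensor<16x13xf32>
///   %padded = tensor.pad %unpack low[0, 0] high[0, 2]
///       : tensor<16x13xf32> to tensor<16x15xf32>
///
/// becomes
///
///   %padded = tensor.pad %src low[0, 0, 0] high[0, 2, 0]
///       : tensor<2x13x8xf32> to tensor<2x15x8xf32>
///   %unpack = linalg.unpack %padded
///       inner_dims_pos = [0] inner_tiles = [8]
///       into %empty : tensor<2x15x8xf32> -> tensor<16x15xf32>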
1266struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
1267 PushDownUnPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
1268 : OpRewritePattern<tensor::PadOp>(context), controlFn(std::move(fun)) {}
1269
1270 LogicalResult matchAndRewrite(tensor::PadOp padOp,
1271 PatternRewriter &rewriter) const override {
1272 linalg::UnPackOp unpackOp =
1273 padOp.getSource().getDefiningOp<linalg::UnPackOp>();
1274 if (!unpackOp)
1275 return failure();
1276
1277 if (!controlFn(&padOp.getSourceMutable()))
1278 return failure();
1279
1280 Location loc = padOp.getLoc();
1281 // Bail out if one of the padded dimensions is a tiled one.
1282 llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
1283 ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1284 llvm::SmallBitVector innerDims(paddedDims.size());
1285 for (int64_t dim : innerDimsPos)
1286 innerDims.flip(dim);
1287 if (paddedDims.anyCommon(innerDims))
1288 return failure();
1289
1290 Value paddingVal = padOp.getConstantPaddingValue();
1291 if (!paddingVal)
1292 return failure();
1293
1294 // If we have `outer_dims_perm` we need to adjust the padded dimensions.
1295 ArrayRef<int64_t> outerDimsPerm = unpackOp.getOuterDimsPerm();
1296 SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
1297 SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
1298 if (!outerDimsPerm.empty()) {
1299 applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
1300 applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
1301 }
1302 // Add zero padding for the point loops.
1303 size_t pointLoopsSize = innerDimsPos.size();
1304 lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1305 highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
1306
1307 auto newPadOp = tensor::PadOp::create(rewriter, loc, /*result=*/Type(),
1308 unpackOp.getSource(), lowPad, highPad,
1309 paddingVal, padOp.getNofold());
1310
1311 // Inject the linalg.unpack right after the packed padOp.
1312 Value outputUnPack =
1313 tensor::EmptyOp::create(rewriter, loc, padOp.getResultType().getShape(),
1314 padOp.getResultType().getElementType());
1315
1316 Value replacement = linalg::UnPackOp::create(
1317 rewriter, loc, newPadOp.getResult(), outputUnPack, innerDimsPos,
1318 unpackOp.getMixedTiles(), outerDimsPerm);
1319 rewriter.replaceOp(padOp, replacement);
1320 return success();
1321 }
1322
1323private:
1324 ControlPropagationFn controlFn;
1325};
1326
1327 // This struct contains information about extract_slice dims.
1328struct SliceDimInfo {
1329 OpFoldResult offset;
1330 OpFoldResult sliceSize;
1331 OpFoldResult outputSize;
1332};
1333
1334/// Return all extract slice operands, if present, for the current
1335/// generic op.
1336static FailureOr<SmallVector<OpOperand *>>
1337getSliceOperands(GenericOp genericOp) {
1338 SmallVector<OpOperand *> sliceOperands;
1339 for (auto operand : genericOp.getDpsInputOperands()) {
1340 auto extractOp = operand->get().getDefiningOp<tensor::ExtractSliceOp>();
1341 if (!extractOp)
1342 continue;
1343 sliceOperands.push_back(operand);
1344 }
1345 if (sliceOperands.empty()) {
1346 return failure();
1347 }
1348 return sliceOperands;
1349}
1350
1351 // Return a map of dims that have partial slices on them so that other operands
1352 // can use this information. Whether a reduction dim has a non-full slice is
1353 // checked by the caller, as that can be used to fold the original extract slice.
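// For example (illustrative), for an operand with indexing map
// (d0, d1) -> (d0, d1) that is fed by
//   %slice = tensor.extract_slice %src[0, 4] [16, 8] [1, 1]
//       : tensor<16x32xf32> to tensor<16x8xf32>
// dim d0 is a full slice and is skipped, while d1 gets
//   {offset = 4, sliceSize = 8, outputSize = 32}.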
1354static FailureOr<llvm::DenseMap<int64_t, SliceDimInfo>>
1355getPartialSliceDimInfo(GenericOp genericOp, OpOperand *sliceOperand) {
1356 tensor::ExtractSliceOp producerSliceOp =
1357 sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
1358 assert(producerSliceOp && "expect a valid ExtractSliceOp");
1359 llvm::DenseMap<int64_t, SliceDimInfo> partialSliceDimMap;
1360 SmallVector<OpFoldResult> offsets = producerSliceOp.getMixedOffsets();
1361 SmallVector<OpFoldResult> sizes = producerSliceOp.getMixedSizes();
1362
1363 SmallVector<OpFoldResult> shape = getAsIndexOpFoldResult(
1364     genericOp.getContext(), producerSliceOp.getSourceType().getShape());
1365
1366 for (auto [idx, expr] : llvm::enumerate(
1367 genericOp.getMatchingIndexingMap(sliceOperand).getResults())) {
1368 // If we have a full slice in a dimension then we don't need to add it to
1369 // the partial slice map.
1370 if (isConstantIntValue(offsets[idx], 0) &&
1371 isEqualConstantIntOrValue(sizes[idx], shape[idx])) {
1372 continue;
1373 }
1374 // We only support partial slices of AffineDimExprs, so bail out if that's not
1375 // the case.
1376 if (!isa<AffineDimExpr>(expr)) {
1377 return failure();
1378 }
1379 SliceDimInfo sliceDimInfo{offsets[idx], sizes[idx], shape[idx]};
1380 int64_t dimPos = cast<AffineDimExpr>(expr).getPosition();
1381 partialSliceDimMap[dimPos] = sliceDimInfo;
1382 }
1383 // Next, check if the dims with partial slice info are used in non-AffineDimExpr
1384 // expressions in other operands; if they are, then bail out.
1385 for (OpOperand &operand : genericOp->getOpOperands()) {
1386 if (operand == *sliceOperand) {
1387 continue;
1388 }
1389 AffineMap IndexingMap = genericOp.getMatchingIndexingMap(&operand);
1390 if (llvm::any_of(IndexingMap.getResults(), [&](AffineExpr expr) {
1391 if (isa<AffineDimExpr>(expr)) {
1392 return false;
1393 }
1394 WalkResult status = expr.walk([&](AffineExpr expr) {
1395 if (auto dimExpr = dyn_cast<AffineDimExpr>(expr)) {
1396 if (partialSliceDimMap.contains(dimExpr.getPosition())) {
1397 return WalkResult::interrupt();
1398 }
1399 }
1400 return WalkResult::advance();
1401 });
1402 if (status.wasInterrupted()) {
1403 return true;
1404 }
1405 return false;
1406 })) {
1407 return failure();
1408 }
1409 }
1410 return partialSliceDimMap;
1411}
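// Illustrative example (shapes and indexing maps are assumed, not from a
// test): with
//   %s = tensor.extract_slice %src[0, 3] [32, 5] [1, 1]
//          : tensor<32x16xf32> to tensor<32x5xf32>
// feeding a generic whose indexing map for %s is (d0, d1) -> (d0, d1),
// d0 is a full slice and is skipped, while d1 is partially sliced, so the
// returned map is { d1 -> {offset = 3, sliceSize = 5, outputSize = 16} }.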
1412
1413static FailureOr<std::tuple<GenericOp, Value>>
1414pushDownExtractSliceOpThroughGenericOp(RewriterBase &rewriter,
1415 GenericOp genericOp,
1416 ControlPropagationFn controlFn) {
1417 if (genericOp.getNumResults() != 1)
1418 return rewriter.notifyMatchFailure(
1419 genericOp, "propagation through multi-result generic is unsupported.");
1420 if (hasGatherSemantics(genericOp))
1421 return rewriter.notifyMatchFailure(
1422 genericOp,
1423 "propagation through generic with gather semantics is unsupported.");
1424 // Collect the sliced operands, if present.
1425 auto maybeSliceOperands = getSliceOperands(genericOp);
1426 if (failed(maybeSliceOperands))
1427 return failure();
1428 SmallVector<OpOperand *> sliceOperands = *maybeSliceOperands;
1429 OpOperand *sliceOperand = nullptr;
1430
1431 bool foundValidOperand = false;
1432 for (auto currSliceOperand : sliceOperands) {
1433 if (controlFn(currSliceOperand)) {
1434 sliceOperand = currSliceOperand;
1435 foundValidOperand = true;
1436 break;
1437 }
1438 }
1439 if (!foundValidOperand) {
1440 return failure();
1441 }
1442 unsigned OperandIndex = sliceOperand->getOperandNumber();
1443
1444 tensor::ExtractSliceOp producerSliceOp =
1445 sliceOperand->get().getDefiningOp<tensor::ExtractSliceOp>();
1446 assert(producerSliceOp && "expect a valid ExtractSliceOp");
1447
1448 if (producerSliceOp.getSource().getType().getRank() !=
1449 producerSliceOp.getResult().getType().getRank()) {
1450 return rewriter.notifyMatchFailure(
1451 genericOp,
1452 "propagation of rank-reducing extract slice is unsupported.");
1453 }
1454
1455 SmallVector<OpFoldResult> strides = producerSliceOp.getMixedStrides();
1456 if (!areAllConstantIntValue(strides, 1))
1457 return rewriter.notifyMatchFailure(
1458 genericOp, "propagation of strided extract slice is unsupported.");
1459
1460 // Check whether propagation of this extract_slice through the generic op is
1461 // supported and, if so, compute the partially sliced iteration dimensions.
1462
1463 auto maybePartialSliceDimMap =
1464 getPartialSliceDimInfo(genericOp, sliceOperand);
1465
1466 if (failed(maybePartialSliceDimMap)) {
1467 return failure();
1468 }
1469
1470 auto partialSliceDimMap = *maybePartialSliceDimMap;
1471
1471
1472 SmallVector<utils::IteratorType> iterators =
1473 genericOp.getIteratorTypesArray();
1474 bool hasPartialReductionDimSlice =
1475 llvm::any_of(partialSliceDimMap, [&](const auto &slice) {
1476 int64_t sliceDim = slice.first;
1477 return iterators[sliceDim] == utils::IteratorType::reduction;
1478 });
1479
1480 // Compute the padding needed for each operand from the partial-slice info.
1481 Location loc = genericOp->getLoc();
1482 AffineExpr dim0, dim1;
1483 bindDims(rewriter.getContext(), dim0, dim1);
1484 auto subMap = AffineMap::get(2, 0, {dim0 - dim1});
1485 auto sub = [&](OpFoldResult v1, OpFoldResult v2) {
1486 return affine::makeComposedFoldedAffineApply(rewriter, loc, subMap,
1487 {v1, v2});
1488 };
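// With the helper above, each partially sliced dim contributes
//   lowPad  = offset
//   highPad = outputSize - offset - sliceSize
// e.g. (illustrative numbers) offset = 3, sliceSize = 5, outputSize = 16 gives
// a low pad of 3 and a high pad of 8, so offset + sliceSize + highPad equals
// the un-sliced extent outputSize.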
1489
1490 MLIRContext *ctx = genericOp.getContext();
1491 SmallVector<Value> paddedInputs;
1492 for (auto [idx, operand] : llvm::enumerate(genericOp.getDpsInputOperands())) {
1493 if (idx == OperandIndex && !hasPartialReductionDimSlice) {
1494 paddedInputs.push_back(producerSliceOp.getSource());
1495 continue;
1496 }
1497 AffineMap IndexingMap = genericOp.getMatchingIndexingMap(operand);
1498 if (IndexingMap.getNumResults() == 0) {
1499 paddedInputs.push_back(operand->get());
1500 continue;
1501 }
1502 SmallVector<OpFoldResult> operandLowPads(IndexingMap.getNumResults(),
1503 getAsIndexOpFoldResult(ctx, 0));
1504 SmallVector<OpFoldResult> operandHighPads(IndexingMap.getNumResults(),
1505 getAsIndexOpFoldResult(ctx, 0));
1506 for (auto [idx, expr] : llvm::enumerate(IndexingMap.getResults())) {
1507 if (!isa<AffineDimExpr>(expr)) {
1508 continue;
1509 }
1510 AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1511 if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
1512 continue;
1513 }
1514 SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
1515 operandLowPads[idx] = sliceDimInfo.offset;
1516 operandHighPads[idx] =
1517 sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
1518 sliceDimInfo.sliceSize);
1519 }
1520 auto paddingValue = ub::PoisonOp::create(
1521 rewriter, loc, getElementTypeOrSelf(operand->get().getType()));
1522 auto paddedOperand = tensor::PadOp::create(
1523 rewriter, loc, Type(), operand->get(), operandLowPads, operandHighPads,
1524 paddingValue, /*nofold=*/false);
1525 paddedInputs.push_back(paddedOperand);
1526 }
1527 AffineMap outputIndexingMap =
1528 genericOp.getMatchingIndexingMap(genericOp.getDpsInitOperand(0));
1529
1530 auto outputShapeType =
1531 llvm::cast<ShapedType>(genericOp.getDpsInitOperand(0)->get().getType());
1532 SmallVector<OpFoldResult> OutputShape = llvm::map_to_vector(
1533 outputShapeType.getShape(),
1534 [&](int64_t sz) -> OpFoldResult { return rewriter.getIndexAttr(sz); });
1535 SmallVector<OpFoldResult> newSizes = OutputShape;
1536 SmallVector<OpFoldResult> outputLowPads(outputIndexingMap.getNumResults(),
1537 getAsIndexOpFoldResult(ctx, 0));
1538 SmallVector<OpFoldResult> outputHighPads(outputIndexingMap.getNumResults(),
1539 getAsIndexOpFoldResult(ctx, 0));
1540 SmallVector<OpFoldResult> newStrides(outputIndexingMap.getNumResults(),
1541 getAsIndexOpFoldResult(ctx, 1));
1542 for (auto [idx, expr] : llvm::enumerate(outputIndexingMap.getResults())) {
1543 if (!isa<AffineDimExpr>(expr)) {
1544 continue;
1545 }
1546 AffineDimExpr dimExpr = cast<AffineDimExpr>(expr);
1547 if (!partialSliceDimMap.contains(dimExpr.getPosition())) {
1548 continue;
1549 }
1550 SliceDimInfo sliceDimInfo = partialSliceDimMap[dimExpr.getPosition()];
1551 outputLowPads[idx] = sliceDimInfo.offset;
1552 outputHighPads[idx] = sub(sub(sliceDimInfo.outputSize, sliceDimInfo.offset),
1553 sliceDimInfo.sliceSize);
1554 OutputShape[idx] = sliceDimInfo.outputSize;
1555 newSizes[idx] = sliceDimInfo.sliceSize;
1556 }
1557 Value newPadOutput;
1558 auto outputElType =
1559 getElementTypeOrSelf(genericOp.getDpsInits()[0].getType());
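// When the generic op never reads its output operand, a fresh tensor.empty of
// the padded shape is sufficient; otherwise the existing init is padded (with
// a poison padding value) so that its contents are preserved in the computed
// region.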
1560 if (isGenericOutsNotUsed(genericOp)) {
1561 newPadOutput =
1562 tensor::EmptyOp::create(rewriter, loc, OutputShape, outputElType);
1563 } else {
1564 auto paddingValue = ub::PoisonOp::create(rewriter, loc, outputElType);
1565 newPadOutput = tensor::PadOp::create(
1566 rewriter, loc, Type(), genericOp.getDpsInits()[0], outputLowPads,
1567 outputHighPads, paddingValue, /*nofold=*/false);
1568 }
1569
1570 auto newGenericOp = linalg::GenericOp::create(
1571 rewriter, loc, newPadOutput.getType(), paddedInputs, {newPadOutput},
1572 genericOp.getIndexingMapsArray(), genericOp.getIteratorTypesArray(),
1573 /*bodyBuild=*/nullptr, linalg::getPrunedAttributeList(genericOp));
1574 rewriter.cloneRegionBefore(genericOp.getRegion(), newGenericOp.getRegion(),
1575 newGenericOp.getRegion().begin());
1576
1577 auto extractOp = tensor::ExtractSliceOp::create(
1578 rewriter, loc,
1579 newGenericOp.getTiedOpResult(newGenericOp.getDpsInitOperand(0)),
1580 outputLowPads, newSizes, newStrides);
1581 Value extractRes = extractOp.getResult();
1582
1583 return std::make_tuple(newGenericOp, extractRes);
1584}
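// Illustrative end-to-end sketch of the rewrite above for an elementwise
// generic. Names, shapes, and the abbreviated pad-body syntax are assumed for
// the example:
//
//   %s = tensor.extract_slice %src[2, 0] [10, 16] [1, 1]
//          : tensor<16x16xf32> to tensor<10x16xf32>
//   %r = linalg.generic ... ins(%s, %b : tensor<10x16xf32>, tensor<10x16xf32>)
//          outs(%init : tensor<10x16xf32>) {...}
//
// becomes (assuming the outs value is not read by the body)
//
//   %pb  = tensor.pad %b low[2, 0] high[4, 0] { ub.poison }
//            : tensor<10x16xf32> to tensor<16x16xf32>
//   %out = tensor.empty() : tensor<16x16xf32>
//   %g   = linalg.generic ... ins(%src, %pb) outs(%out) {...}
//   %r2  = tensor.extract_slice %g[2, 0] [10, 16] [1, 1]
//            : tensor<16x16xf32> to tensor<10x16xf32>
//
// i.e. the slice is replaced by poison padding on the other operands and is
// re-applied once, after the generic op.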
1585
1586class PushDownExtractSliceOpThroughGenericOp final
1587 : public OpRewritePattern<GenericOp> {
1588public:
1589 PushDownExtractSliceOpThroughGenericOp(MLIRContext *context,
1590 ControlPropagationFn fun)
1591 : OpRewritePattern<GenericOp>(context), controlFn(std::move(fun)) {}
1592
1593 LogicalResult matchAndRewrite(GenericOp genericOp,
1594 PatternRewriter &rewriter) const override {
1595 auto genericAndRepl =
1596 pushDownExtractSliceOpThroughGenericOp(rewriter, genericOp, controlFn);
1597 if (failed(genericAndRepl))
1598 return failure();
1599 rewriter.replaceOp(genericOp, std::get<1>(*genericAndRepl));
1600 return success();
1601 }
1602
1603private:
1604 ControlPropagationFn controlFn;
1605};
1606
1607} // namespace
1608
1609 void mlir::linalg::populateDataLayoutPropagationPatterns(
1610 RewritePatternSet &patterns,
1611 const ControlPropagationFn &controlPackUnPackPropagation,
1612 bool PoisonPaddingOk) {
1613 patterns.insert<BubbleUpPackThroughPadOp, BubbleUpPackOpThroughReshapeOp,
1614 PushDownUnPackThroughPadOp, PushDownUnPackOpThroughReshapeOp>(
1615 patterns.getContext(), controlPackUnPackPropagation);
1616 patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
1617 PushDownUnPackOpThroughGenericOp>(
1618 patterns.getContext(), controlPackUnPackPropagation, PoisonPaddingOk);
1619}
1620
1621 void mlir::linalg::populateExtractSliceSinkingPatterns(
1622 RewritePatternSet &patterns,
1623 const ControlPropagationFn &controlPackUnPackPropagation) {
1624 patterns.insert<PushDownExtractSliceOpThroughGenericOp>(
1625 patterns.getContext(), controlPackUnPackPropagation);
1626}
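// A minimal usage sketch (assuming a pass that owns a RewritePatternSet and
// drives it with the greedy rewriter; the control function below is a
// stand-in that accepts every candidate operand):
//
//   RewritePatternSet patterns(ctx);
//   linalg::ControlPropagationFn control = [](OpOperand *opOperand) {
//     return true;
//   };
//   linalg::populateDataLayoutPropagationPatterns(patterns, control,
//                                                 /*PoisonPaddingOk=*/true);
//   linalg::populateExtractSliceSinkingPatterns(patterns, control);
//   if (failed(applyPatternsGreedily(op, std::move(patterns))))
//     signalPassFailure();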