MLIR  20.0.0git
TosaReduceTransposes.cpp
Go to the documentation of this file.
1 //===- TosaReduceTransposes.cpp -------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 // ----------
10 // Motivation:
11 // ----------
12 
13 // Some legalization pathways introduce redundant tosa.TRANSPOSE
14 // operations that result in avoidable data movement. For example,
15 // PyTorch -> TOSA contains a lot of unnecessary transposes due
16 // to conversions between NCHW and NHWC.
17 
18 // We wish to remove all the ones that we can, since in general
19 // it is possible to remove the overwhelming majority.
20 
21 // -------------------
22 // High-Level Overview:
23 // -------------------
24 
25 // The pass works through the transpose operators in the program. It begins at
26 // some transpose operator with an associated permutations tensor. It traverses
27 // upwards through the dependencies of this transpose and verifies that we
28 // encounter only operators with the TosaElementwiseOperator trait and terminate
29 // in either constants, reshapes, or transposes.
30 
31 // We then evaluate whether there are any additional restrictions (the
32 // transposes it terminates in must invert the one we began at, and the reshapes
33 // must be ones in which we can fold the transpose into), and then we hoist the
34 // transpose through the intervening operators, folding it at the constants,
35 // reshapes, and transposes.
36 
37 // Finally, we ensure that we do not need both the transposed form (the form
38 // that had the transpose hoisted through it) and the untransposed form (which
39 // it was prior), by analyzing the usages of those dependent operators of a
40 // given transpose we are attempting to hoist and replace.
41 
42 // If they are such that it would require both forms to be necessary, then we do
43 // not replace the hoisted transpose, causing the new chain to be dead.
44 // Otherwise, we do and the old chain (untransposed form) becomes dead. Only one
45 // chain will ever then be live, resulting in no duplication.
46 
47 // We then perform a simple one-pass DCE, so no canonicalization is necessary.
48 
49 // -----------
50 // Future Work:
51 // -----------
52 
53 // (1) Evaluate tradeoffs with permitting ConstOp to be duplicated across
54 // hoisted
55 // transposes with different permutation tensors.
56 
57 // (2) Expand the class of foldable upstream ReshapeOp we permit beyond
58 // N -> 1x1x...x1xNx1x...x1x1.
59 
60 // (3) Enhance the pass to permit folding arbitrary transpose pairs, beyond
61 // those that form the identity.
62 
63 // (4) Add support for more instructions besides TosaElementwiseOperator as
64 // the intervening ones (for example, the reduce_* operators).
65 
66 // (5) Support hoisting transposes up to an input parameter.
67 
68 //===----------------------------------------------------------------------===//
69 
74 #include "mlir/IR/Iterators.h"
75 #include "mlir/IR/Matchers.h"
76 #include "llvm/ADT/TypeSwitch.h"
77 #include <memory>
78 #include <set>
79 #include <stack>
80 
81 namespace mlir {
82 namespace tosa {
83 #define GEN_PASS_DEF_TOSAREDUCETRANSPOSES
84 #include "mlir/Dialect/Tosa/Transforms/Passes.h.inc"
85 } // namespace tosa
86 } // namespace mlir
87 
88 using namespace mlir;
89 using namespace mlir::tosa;
90 
91 //===----------------------------------------------------------------------===//
92 // TOSA Reduce Transposes Pass.
93 //===----------------------------------------------------------------------===//
94 
95 namespace {
96 
97 struct TosaReduceTransposes final
98  : public tosa::impl::TosaReduceTransposesBase<TosaReduceTransposes> {
99  void runOnOperation() override;
100 
101 private:
102  // This will collect all the data dependencies for the given Operation
103  // up to and including ConstOp, ReshapeOp, and TransposeOp.
104  bool collectFanIn(Operation *op, SetVector<Operation *> &collected);
105  bool convertDependentOps(SetVector<Operation *> &dependentOps,
106  DenseMap<Value, Value> &valuesMap,
107  IRRewriter &rewriter,
108  ArrayRef<int32_t> hoistedPerms);
109 
110  // Checks if the two permutations, when applied consecutively, result
111  // in the identity.
112  bool areInvolutionTransposes(ArrayRef<int32_t> perms1,
113  ArrayRef<int32_t> perms2);
114 
115  // This is meant to apply to operations with the TosaElementwiseOperator
116  // trait.
117  std::optional<Value>
118  buildMappedToValue(Operation *op, const DenseMap<Value, Value> &valuesMap,
119  IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
120 
121  // This updates valuesMap when we encounter another TransposeOp as a
122  // dependency of the hoisted one. %0 = tosa.transpose %arg0 <- applies to
123  // this %1 = tosa.transpose %0 <- when tracking back from this
124  std::optional<Value>
125  buildMappedToValue(TransposeOp transposeOp,
126  const DenseMap<Value, Value> &valuesMap,
127  IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
128 
129  // Checks if ReshapeOp can have hoisted TransposeOp folded into it. If so,
130  // it creates new ReshapeOp with that fold.
131  std::optional<Value>
132  buildMappedToValue(ReshapeOp reshapeOp,
133  const DenseMap<Value, Value> &valuesMap,
134  IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
135 
136  // We may have something like:
137  // %0 = tosa.const
138  // %1 = tosa.transpose
139  // %2 = tosa.add %0, %1
140  // %3 = tosa.transpose %2
141  // that --tosa-layerwise-const-fold wouldn't handle. This use shows up
142  // in MobilenetV3.
143  std::optional<Value>
144  buildMappedToValue(ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
145  IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms);
146 
147  // Checks which TransposeOp we should "replace", turning their converted
148  // chains of ops, through which they were propagated, "live", and the old code
149  // "dead." Attempts to avoid doing so when doing so would result in the old
150  // code staying "live," resulting in duplication.
151  std::set<TransposeOp> getGoodReplacements(
152  ArrayRef<int32_t> perms,
153  std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
154  &transposeInfo);
155 
156  // Helper function for dependenciesAreValid.
157  bool userNotContainedInValidTransposeDependencies(
158  Operation *user, std::set<TransposeOp> &validTransposes,
159  std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
160  &transposeInfo);
161 
162  // Helper function for getGoodReplacements to check if some TransposeOp's
163  // dependencies are OK.
164  bool dependenciesAreValid(
165  ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
166  std::set<TransposeOp> &validTransposes,
167  std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
168  &transposeInfo);
169 
170  // Applies perms to the DenseElementsAttr.
171  // If it returns std::nullopt, it also triggers pass failure, since verifier
172  // guarantees from TOSA are not in place (and otherwise, if used elsewhere,
173  // it should fail).
174  // This is a basic API and may benefit from refactor into the core MLIR APIs.
175  std::optional<DenseElementsAttr>
176  transposeDenseAttribute(DenseElementsAttr input, ArrayRef<int32_t> perms);
177 };
178 
179 std::optional<DenseElementsAttr>
180 TosaReduceTransposes::transposeDenseAttribute(DenseElementsAttr input,
181  ArrayRef<int32_t> perms) {
182  RankedTensorType oldType = llvm::cast<RankedTensorType>(input.getType());
183  RankedTensorType newType =
184  RankedTensorType::get(applyTOSAPermutation(oldType.getShape(), perms),
185  oldType.getElementType());
186  size_t rank = oldType.getRank();
187 
188  // Asserted by TransposeOp verifier and TOSA disallowing tensor with dimension
189  // 0. If not in place, something is very wrong.
190  if (rank <= 0 || oldType.getNumElements() <= 0 || perms.size() != rank) {
191  signalPassFailure();
192  return std::nullopt;
193  }
194 
195  if (input.isSplat())
196  return input.reshape(newType);
197 
198  // The algorithm is approximately as follows:
199  // input: perms, input flat array, input tensor type
200  // (1/2) determine the strides of input/output if
201  // they were strided in row-major order. (3) adjust the strides for the
202  // input to be in the same order of indices as the output is written.
203  // (4) process dimension by dimension. example: perms 2, 0, 1; input
204  // 2x3x4; output 4x2x3 for i ... 4, j ... 2, k ... 3: output[i][j][k] =
205  // input[j][k][i] output[6i + 3j + k] = input[12j + 4k + i] and we adjust
206  // input strides to be as input[i + 12j + 4k] so we may process
207  // layer-by-layer.
208 
209  // Step 1/2: Strides for input. We ignore output since row-major and can just
210  // push_back.
211 
212  SmallVector<int64_t> originalInputStrides(rank);
213  originalInputStrides[rank - 1] = 1;
214  // index with int64_t to avoid overflow
215  for (int64_t i = rank - 2; i >= 0; i--)
216  originalInputStrides[i] =
217  originalInputStrides[i + 1] * oldType.getDimSize(i + 1);
218 
219  // Step 3: Transpose strides of input to be same indexing (i, j, k, ...) as
220  // output which is done in row-major order.
221 
222  SmallVector<int64_t> newInputStrides;
223  newInputStrides.reserve(rank);
224  for (int32_t v : perms)
225  newInputStrides.push_back(originalInputStrides[v]);
226 
227  // Step 4: Write out the transposed "flat array" dimension by dimension.
228 
229  auto inputArray = input.getValues<Attribute>();
230  SmallVector<std::pair<int64_t, int64_t>> boundsAndStrides;
231  for (size_t i = 0; i < rank; i++)
232  boundsAndStrides.push_back({newType.getDimSize(i), newInputStrides[i]});
233 
234  SmallVector<Attribute> resultArray;
235  resultArray.reserve(inputArray.size());
236 
237  std::function<void(int64_t,
238  SmallVector<std::pair<int64_t, int64_t>>::const_iterator)>
239  processTransposeDim = [&](auto accumulatedIndex, auto it) {
240  if (it == boundsAndStrides.end()) {
241  resultArray.push_back(inputArray[accumulatedIndex]);
242  return;
243  }
244 
245  for (int64_t i = 0; i < it->first; i++) {
246  int64_t j = accumulatedIndex + i * it->second;
247  processTransposeDim(j, it + 1);
248  }
249  };
250 
251  processTransposeDim(0, boundsAndStrides.begin());
252 
253  return DenseElementsAttr::get(newType, resultArray);
254 }
255 
256 // The SetVector should only contain ConstOp, ReshapeOp, TransposeOp
257 // as the sources of the data dependencies, and TosaElementWiseOperator
258 // after that, if the function returns true.
259 bool TosaReduceTransposes::collectFanIn(Operation *op,
260  SetVector<Operation *> &collected) {
261  // Can occur if defined through the parameter to a func.func.
262  if (!op)
263  return false;
264 
265  if (!llvm::isa_and_present<tosa::TosaDialect>(op->getDialect()))
266  return false;
267 
268  // Prevent extra work if already seen.
269  if (collected.contains(op))
270  return true;
271 
272  // Throw it out so later don't have to deal with this.
273  if (op->getNumResults() != 1 ||
274  !llvm::isa<RankedTensorType>(op->getResult(0).getType()))
275  return false;
276 
277  // We don't wish to traverse up a ReshapeOp, since generally we can't
278  // propagate a TransposeOp through it. TransposeOp, ReshapeOp, ConstOp
279  // will have no in-edges in the data dependency graph we construct for
280  // the downstream TransposeOp.
281  if (!llvm::isa<tosa::TransposeOp>(op) && !llvm::isa<tosa::ReshapeOp>(op) &&
282  !llvm::isa<tosa::ConstOp>(op)) {
283 
285  return false;
286 
287  for (Value operand : op->getOperands())
288  // If this is a problem in future, think about alternatives to recursion.
289  if (!collectFanIn(operand.getDefiningOp(), collected))
290  return false;
291  }
292 
293  // Insert in topological order.
294  collected.insert(op);
295 
296  return true;
297 }
298 
299 // Assuming that due to the verification of TransposeOp perms arrays are
300 // permutations of 0 - perms.size() - 1.
301 bool TosaReduceTransposes::areInvolutionTransposes(ArrayRef<int32_t> perms1,
302  ArrayRef<int32_t> perms2) {
303  if (perms1.size() != perms2.size())
304  return false;
305  int32_t n = perms1.size();
306  for (int32_t i = 0; i < n; i++)
307  if (perms2[perms1[i]] != i)
308  return false;
309  return true;
310 }
311 
312 // Primary overload for those with TosaElementwiseOperator trait.
313 // The other ones handle the case of the operations that occur at the
314 // roots of the data dependency graph (ConstOp, ReshapeOp, TransposeOp).
315 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
316  Operation *op, const DenseMap<Value, Value> &valuesMap,
317  IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
318  if (op->getNumResults() != 1 ||
320  return std::nullopt;
321 
322  auto resultType = llvm::cast<RankedTensorType>(op->getResult(0).getType());
323  SmallVector<Value, 3> operands;
324  for (Value v : op->getOperands()) {
325  if (valuesMap.contains(v)) {
326  operands.push_back(valuesMap.at(v));
327  } else {
328  return std::nullopt;
329  }
330  }
331 
332  // Conceptually, we propagate the hoisted TransposeOp through
333  // these interveaning operations. For example,
334 
335  // %0 = tosa.clamp %input : (tensor<2x3xi32>) -> tensor<2x3xi32>
336  // %1 = tosa.transpose %0 {perms = [1, 0]} : (tensor<2x3xi32>) ->
337  // tensor<3x2xi32>
338 
339  // becomes:
340  // %0 = tosa.transpose %input {perms = [1, 0]} : (tensor<2x3xi32>) ->
341  // tensor<3x2xi32>
342  // %1 = tosa.clamp %0 : (tensor<3x2xi32>) -> tensor<3x2xi32>)
343 
344  // We construct this new tosa.clamp here, but it doesn't
345  // turn "live" until the transpose being hoisted through this chain
346  // is replaced with the proper value from the new chain.
347 
348  return rewriter
349  .create(op->getLoc(), op->getName().getIdentifier(), operands,
351  applyTOSAPermutation(resultType.getShape(), hoistedPerms),
352  resultType.getElementType()),
353  op->getAttrs())
354  ->getResult(0);
355 }
356 
357 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
358  TransposeOp transposeOp, const DenseMap<Value, Value> &valuesMap,
359  IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
360  SmallVector<int32_t> perms;
361  if (failed(transposeOp.getConstantPerms(perms)) ||
362  !areInvolutionTransposes(hoistedPerms, perms))
363  return std::nullopt;
364  return transposeOp.getInput1();
365 }
366 
367 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
368  ReshapeOp reshapeOp, const DenseMap<Value, Value> &valuesMap,
369  IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
370  auto reshapeOutput = reshapeOp.getOutput();
371  auto reshapeInputType =
372  llvm::dyn_cast<RankedTensorType>(reshapeOp.getInput1().getType());
373  auto reshapeInputShape = reshapeInputType.getShape();
374  // want reshape N -> 1x1x...x1xNx1x...x1x1
375  if (!reshapeInputType || reshapeInputShape.size() != 1)
376  return std::nullopt;
377  auto reshapeOutputType =
378  llvm::cast<RankedTensorType>(reshapeOutput.getType());
379 
380  // Instead of inserting a TransposeOp here, we check if we can fold it into
381  // the ReshapeOp. There is more complex cases where this is possible, and
382  // this check can be extended.
383 
384  // Checking if reshape is N -> 1x1x...x1xNx1x...x1x1
385  auto shape = reshapeOutputType.getShape();
386  size_t ones = llvm::count(shape, 1);
387  // N == 1 and N != 1
388  if (ones != shape.size() - 1 &&
389  !(ones == shape.size() && reshapeInputShape[0] == 1))
390  return std::nullopt;
391 
392  // Do not insert a TransposeOp, instead we fold the reshape and its attribute.
393  auto foldedReshape = rewriter.create<ReshapeOp>(
394  reshapeOp.getLoc(),
395  RankedTensorType::get(applyTOSAPermutation(shape, hoistedPerms),
396  reshapeOutputType.getElementType()),
397  reshapeOp.getInput1(),
398  rewriter.getDenseI64ArrayAttr(
399  applyTOSAPermutation(reshapeOp.getNewShape(), hoistedPerms)));
400  return foldedReshape->getResult(0);
401 }
402 
403 std::optional<Value> TosaReduceTransposes::buildMappedToValue(
404  ConstOp constOp, const DenseMap<Value, Value> &valuesMap,
405  IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
406  auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(constOp.getValue());
407  if (!denseAttr)
408  return std::nullopt;
409  auto maybeNewDenseAttr = transposeDenseAttribute(denseAttr, hoistedPerms);
410  if (!maybeNewDenseAttr.has_value())
411  return std::nullopt;
412  auto newDenseAttr = maybeNewDenseAttr.value();
413  auto newConstOp = rewriter.create<ConstOp>(
414  constOp.getLoc(), newDenseAttr.getType(), newDenseAttr);
415  return newConstOp->getResult(0);
416 }
417 
418 bool TosaReduceTransposes::convertDependentOps(
419  SetVector<Operation *> &dependentOps, DenseMap<Value, Value> &valuesMap,
420  IRRewriter &rewriter, ArrayRef<int32_t> hoistedPerms) {
421 
422  for (Operation *op : dependentOps) {
423  if (!op || op->getNumResults() != 1)
424  return false;
425 
426  Value priorValue = op->getResult(0);
427 
428  // It's possible on a prior transposeOp we had the same dependency and
429  // already resolved it.
430  if (valuesMap.contains(priorValue))
431  continue;
432 
433  // Keep converted ops close to the original.
434  rewriter.setInsertionPointAfter(op);
435 
436  std::optional<Value> maybeValue =
438  .Case<TransposeOp, ReshapeOp, ConstOp>([&](auto transposeOp) {
439  return buildMappedToValue(transposeOp, valuesMap, rewriter,
440  hoistedPerms);
441  })
442  .Default([&](Operation *op) {
443  return buildMappedToValue(op, valuesMap, rewriter, hoistedPerms);
444  });
445 
446  if (!maybeValue.has_value())
447  return false;
448 
449  valuesMap[priorValue] = maybeValue.value();
450  }
451 
452  return true;
453 }
454 
455 bool TosaReduceTransposes::userNotContainedInValidTransposeDependencies(
456  Operation *user, std::set<TransposeOp> &validTransposes,
457  std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
458  &transposeInfo) {
459  return llvm::none_of(
460  transposeInfo,
461  [&validTransposes,
462  user](const std::pair<TransposeOp, SetVector<Operation *>> &info) {
463  const auto &[transposeOp, dependentOps] = info;
464  return validTransposes.count(transposeOp) &&
465  dependentOps.contains(user);
466  });
467 }
468 
469 // Dependencies are valid for an operation if none of them occur outside
470 // of the proper fan-in cones of the hoisted TransposeOp with the same perms
471 // that we can replace. Described in more detail within.
472 bool TosaReduceTransposes::dependenciesAreValid(
473  ArrayRef<int32_t> perms, const SetVector<Operation *> &dependentOps,
474  std::set<TransposeOp> &validTransposes,
475  std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
476  &transposeInfo) {
477  for (Operation *op : dependentOps) {
478 
479  // It's OK wherever ConstOp has uses -- in the worst case, we duplicate.
480  // This can be changed later if we find the memory impact is too high.
481  if (llvm::isa<ConstOp>(op))
482  continue;
483 
484  for (OpOperand &use : op->getUses()) {
485  // Want the uses to be (1) contained in the dependentOps of other
486  // validTransposes, or (2) to be directly used in a TransposeOp with the
487  // same perms. For (2) it means the fan-in is a subset of our
488  // dependentOps, so it is also a validTranspose that will eventually be
489  // replaced.
490  Operation *user = use.getOwner();
491  if (auto otherTranspose = llvm::dyn_cast<TransposeOp>(user)) {
492  SmallVector<int32_t> otherPerms;
493 
494  // Can later think about cases where transpose -> transpose
495  // or reshape -> transpose, where the transposes are not necessarily
496  // the same perms as the hoisted, if implementing a more general
497  // transform. These could be permitted.
498  if (failed(otherTranspose.getConstantPerms(otherPerms)) ||
499  !llvm::equal(perms, otherPerms))
500  return false;
501  } else if (userNotContainedInValidTransposeDependencies(
502  user, validTransposes, transposeInfo)) {
503  return false;
504  }
505  }
506  }
507 
508  return true;
509 }
510 
511 // Getting the set of TransposeOp that we can replace without causing
512 // the old fan-in cones of any TransposeOp to remain "live", i.e, -- not being
513 // dead code. This is done by iterating the set until convergence, since
514 // if you are used outside your own fan-in cone, it's possible to be used
515 // in another fan-in cone of a TransposeOp that is being replaced -- unless
516 // we find that that one has a usage outside of it too.
517 std::set<TransposeOp> TosaReduceTransposes::getGoodReplacements(
518  ArrayRef<int32_t> perms,
519  std::vector<std::pair<TransposeOp, SetVector<Operation *>>>
520  &transposeInfo) {
521  // Initially, we assume they are all good to replace,
522  // and we whittle them down based on our criteria.
523  std::set<TransposeOp> ableToReplace;
524  for (const auto &[transposeOp, _] : transposeInfo)
525  ableToReplace.insert(transposeOp);
526 
527  bool gotRid;
528  do {
529  gotRid = false;
530  for (const auto &[transposeOp, dependentOps] : transposeInfo) {
531  // We don't care about it. Already invalidated.
532  if (!ableToReplace.count(transposeOp))
533  continue;
534 
535  // Check for validity.
536  if (!dependenciesAreValid(perms, dependentOps, ableToReplace,
537  transposeInfo)) {
538  ableToReplace.erase(transposeOp);
539  gotRid = true;
540  break;
541  }
542  }
543 
544  } while (gotRid);
545 
546  return ableToReplace;
547 }
548 
549 void TosaReduceTransposes::runOnOperation() {
550  // We want to operate only within a single block.
551  if (!getOperation().getRegion().hasOneBlock())
552  return;
553 
554  IRRewriter rewriter(&getContext());
555  // For each perms, maintain a mapping for converted ops, avoid duplication.
557  // For each perms, we keep track of which TransposeOp are eligible
558  // for replacement alongside their dependentOps.
560  std::vector<std::pair<TransposeOp, SetVector<Operation *>>>>
561  permsToTransposeInfo;
562 
563  // Necessary for lifetime, since DenseMap keeps a copy of the ArrayRef.
564  // Use SmallVector for perms (common-case is <= 4) but std::vector otherwise
565  // since no guarantee of smallness.
566  std::vector<SmallVector<int32_t>> collectedPerms;
567 
568  // This keeps track of the order across all eligible-for-replacement
569  // TransposeOp and their perms, a necessity for the final replacements.
570  std::stack<std::pair<TransposeOp, ArrayRef<int32_t>>> totalTransposeOrder;
571 
572  // We want to reserve the space up front, since SmallVector stores some data
573  // internally and the ArrayRef can reference that, which we don't want to get
574  // invalidated.
575  size_t expectedMaxPerms = 0;
576  getOperation().walk([&](TransposeOp) { expectedMaxPerms += 1; });
577  collectedPerms.reserve(expectedMaxPerms);
578 
579  getOperation().walk([&](TransposeOp transposeOp) {
580  SetVector<Operation *> dependentOps;
581  collectedPerms.emplace_back();
582  SmallVector<int32_t> &perms = collectedPerms.back();
583 
584  // Dynamic shapes are OK, but the incompatible ones will be rejected later.
585  auto input = transposeOp.getInput1();
586  auto output = transposeOp.getOutput();
587 
588  // However, we don't support unranked tensors.
589  if (!llvm::isa<RankedTensorType>(input.getType()) ||
590  !llvm::isa<RankedTensorType>(output.getType()))
591  return;
592 
593  // No transformation when transpose permutation non-constant.
594  if (failed(transposeOp.getConstantPerms(perms)))
595  return;
596 
597  // We let --canonicalize deal with identity transpose.
598  if (llvm::equal(llvm::seq<int32_t>(0, perms.size()), perms))
599  return;
600 
601  // Can fail if some set of basic invariants is not met that we want to
602  // perform our conversions.
603  if (!collectFanIn(input.getDefiningOp(), dependentOps))
604  return;
605 
606  // Want to associate valuesMap for already converted of the same perms,
607  // since it's possible multiple hoisted transposes w/ different perms
608  // converge on an op, which would result in different transformations.
609  DenseMap<Value, Value> &valuesMap = permsToValues[perms];
610 
611  // Attempt to perform the conversions and placements into IR
612  // without turning inserted code "live". Also fills out valuesMap.
613  // Fails if there is an intermediary we do not support.
614  if (!convertDependentOps(dependentOps, valuesMap, rewriter, perms))
615  // Some additional operations may have been inserted, but will be
616  // removed by dead code elimination.
617  return;
618 
619  // This should not happen. If it does -- it's unexpected,
620  // so we fail the pass.
621  if (!valuesMap.contains(input))
622  return signalPassFailure();
623 
624  // It's possible the types are not compatible (because of dynamic shapes),
625  // and in these cases, want to resolve dynamic shapes before running the
626  // pass.
627  if (output.getType() != valuesMap.at(input).getType())
628  return;
629 
630  auto &transposeInfo = permsToTransposeInfo[perms];
631 
632  // In general, we might also want to introduce "newDependentOps"
633  // if there are new usages that don't fall inside the original fan-ins
634  // (like the TransposeOp we insert for ReshapeOp),
635  // but in this case, that is specialized enough and overlaps
636  // with another direct-use TransposeOp case we need to cover anyway.
637  transposeInfo.push_back({transposeOp, dependentOps});
638 
639  // This is for the final replacement across all transposes.
640  totalTransposeOrder.push({transposeOp, perms});
641  });
642 
643  // We want to do a full fan-in analysis on a perms-level,
644  // since if we do it on a multi-perms level, and they share (due to a shared
645  // dependency on a Reshape) then we would also get duplicate ops.
646  // Const is special cased.
647  std::set<TransposeOp> ableToReplace;
648  for (auto &[perms, transposeInfo] : permsToTransposeInfo) {
649  // Gives us back replacements that would never result in any duplicate
650  // operations being inserted by us in the IR (i.e, our goal is only to
651  // remove transposes, and not create a "new chain" to do so, but replace
652  // the existing chains).
653  // Ideally, --canonicalize is run before this pass, since it helps this
654  // analysis by removing dead code to allow more potentially acceptable
655  // transformations.
656  auto goodReplacementsForPerms = getGoodReplacements(perms, transposeInfo);
657  ableToReplace.insert(goodReplacementsForPerms.begin(),
658  goodReplacementsForPerms.end());
659  }
660 
661  // We want to do replacement across all transposes
662  // in reverse order, due to invalidation of valuesMap mappings
663  // if we did it otherwise.
664  while (!totalTransposeOrder.empty()) {
665  auto [transposeOp, perms] = totalTransposeOrder.top();
666  totalTransposeOrder.pop();
667 
668  if (ableToReplace.count(transposeOp) == 0)
669  continue;
670 
671  auto &valuesMap = permsToValues[perms];
672  auto input = transposeOp.getInput1();
673 
674  // The purpose of this reverse iteration
675  // is to avoid valuesMap invalidation. If it happens,
676  // something is wrong.
677  if (!valuesMap.contains(input))
678  return signalPassFailure();
679 
680  rewriter.replaceOp(transposeOp, valuesMap.at(input));
681  }
682 
683  // We can remove all dead code by going in reverse.
684  // This is because we would remove usages before we
685  // see the users.
686  getOperation().walk<WalkOrder::PostOrder, ReverseIterator>(
687  [&](Operation *op) {
688  if (isOpTriviallyDead(op))
689  rewriter.eraseOp(op);
690  });
691 }
692 
693 } // namespace
static MLIRContext * getContext(OpFoldResult val)
Attributes are known-constant values of operations.
Definition: Attributes.h:25
DenseI64ArrayAttr getDenseI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:207
An attribute that represents a reference to a dense vector or tensor object.
auto getValues() const
Return the held element values as a range of the given type.
bool isSplat() const
Returns true if this attribute corresponds to a splat, i.e.
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
ShapedType getType() const
Return the type of this ElementsAttr, guaranteed to be a vector or tensor with static shape.
DenseElementsAttr reshape(ShapedType newType)
Return a new DenseElementsAttr that has the same data as the current attribute, but has been reshaped...
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:772
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
void setInsertionPointAfter(Operation *op)
Sets the insertion point to the node after the specified operation, which will cause subsequent inser...
Definition: Builders.h:421
This class represents an operand of an operation.
Definition: Value.h:267
This class indicates that an op is tosa-elementwise (permits broadcasting, unlike Elementwise trait).
Definition: TosaOps.h:91
StringAttr getIdentifier() const
Return the name of this operation as a StringAttr.
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:750
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition: Operation.h:220
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
ArrayRef< NamedAttribute > getAttrs()
Return all of the attributes on this operation.
Definition: Operation.h:512
OperationName getName()
The name of an operation is the key identifier for it.
Definition: Operation.h:119
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
use_range getUses()
Returns a range of all uses, which is useful for iterating over all uses.
Definition: Operation.h:847
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
SmallVector< T > applyTOSAPermutation(ArrayRef< T > input, ArrayRef< int32_t > perms)
Include the generated interface declarations.
bool isOpTriviallyDead(Operation *op)
Return true if the given operation is unused, and has no side effects on memory that prevent erasing.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
This iterator enumerates elements in "reverse" order.
Definition: Iterators.h:29
Eliminates variable at the specified position using Fourier-Motzkin variable elimination.