MLIR  22.0.0git
RewriteInsertsPass.cpp
Go to the documentation of this file.
1 //===- RewriteInsertsPass.cpp - MLIR conversion pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to rewrite sequential chains of
10 // `spirv::CompositeInsert` operations into `spirv::CompositeConstruct`
11 // operations.
12 //
13 //===----------------------------------------------------------------------===//
14 
16 
18 #include "mlir/IR/Builders.h"
19 
20 namespace mlir {
21 namespace spirv {
22 #define GEN_PASS_DEF_SPIRVREWRITEINSERTSPASS
23 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
24 } // namespace spirv
25 } // namespace mlir
26 
27 using namespace mlir;
28 
29 namespace {
30 
31 /// Replaces sequential chains of `spirv::CompositeInsertOp` operation into
32 /// `spirv::CompositeConstructOp` operation if possible.
33 class RewriteInsertsPass
34  : public spirv::impl::SPIRVRewriteInsertsPassBase<RewriteInsertsPass> {
35 public:
36  void runOnOperation() override;
37 
38 private:
39  /// Collects a sequential insertion chain by the given
40  /// `spirv::CompositeInsertOp` operation, if the given operation is the last
41  /// in the chain.
42  LogicalResult
43  collectInsertionChain(spirv::CompositeInsertOp op,
45 };
46 
47 } // namespace
48 
49 void RewriteInsertsPass::runOnOperation() {
51  getOperation().walk([this, &workList](spirv::CompositeInsertOp op) {
53  if (succeeded(collectInsertionChain(op, insertions)))
54  workList.push_back(insertions);
55  });
56 
57  for (const auto &insertions : workList) {
58  auto lastCompositeInsertOp = insertions.back();
59  auto compositeType = lastCompositeInsertOp.getType();
60  auto location = lastCompositeInsertOp.getLoc();
61 
62  SmallVector<Value, 4> operands;
63  // Collect inserted objects.
64  for (auto insertionOp : insertions)
65  operands.push_back(insertionOp.getObject());
66 
67  OpBuilder builder(lastCompositeInsertOp);
68  auto compositeConstructOp = spirv::CompositeConstructOp::create(
69  builder, location, compositeType, operands);
70 
71  lastCompositeInsertOp.replaceAllUsesWith(
72  compositeConstructOp->getResult(0));
73 
74  // Erase ops.
75  for (auto insertOp : llvm::reverse(insertions)) {
76  auto *op = insertOp.getOperation();
77  if (op->use_empty())
78  insertOp.erase();
79  }
80  }
81 }
82 
83 LogicalResult RewriteInsertsPass::collectInsertionChain(
84  spirv::CompositeInsertOp op,
86  if (isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
87  return failure();
88 
89  auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
90  // TODO: handle nested composite object.
91  if (indicesArrayAttr.size() == 1) {
92  auto numElements = cast<spirv::CompositeType>(op.getComposite().getType())
93  .getNumElements();
94 
95  auto index = cast<IntegerAttr>(indicesArrayAttr[0]).getInt();
96  // Need a last index to collect a sequential chain.
97  if (index + 1 != numElements)
98  return failure();
99 
100  insertions.resize(numElements);
101  while (true) {
102  insertions[index] = op;
103 
104  if (index == 0)
105  return success();
106 
107  op = op.getComposite().getDefiningOp<spirv::CompositeInsertOp>();
108  if (!op)
109  return failure();
110 
111  --index;
112  indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
113  if ((indicesArrayAttr.size() != 1) ||
114  (cast<IntegerAttr>(indicesArrayAttr[0]).getInt() != index))
115  return failure();
116  }
117  }
118  return failure();
119 }
This class helps build Operations.
Definition: Builders.h:205
Include the generated interface declarations.