MLIR 22.0.0git
RewriteInsertsPass.cpp
Go to the documentation of this file.
1//===- RewriteInsertsPass.cpp - MLIR conversion pass ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a pass to rewrite sequential chains of
10// `spirv::CompositeInsert` operations into `spirv::CompositeConstruct`
11// operations.
12//
13//===----------------------------------------------------------------------===//
14
16
18#include "mlir/IR/Builders.h"
19
20namespace mlir {
21namespace spirv {
22#define GEN_PASS_DEF_SPIRVREWRITEINSERTSPASS
23#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
24} // namespace spirv
25} // namespace mlir
26
27using namespace mlir;
28
29namespace {
30
31/// Replaces sequential chains of `spirv::CompositeInsertOp` operation into
32/// `spirv::CompositeConstructOp` operation if possible.
33class RewriteInsertsPass
34 : public spirv::impl::SPIRVRewriteInsertsPassBase<RewriteInsertsPass> {
35public:
36 void runOnOperation() override;
37
38private:
39 /// Collects a sequential insertion chain by the given
40 /// `spirv::CompositeInsertOp` operation, if the given operation is the last
41 /// in the chain.
42 LogicalResult
43 collectInsertionChain(spirv::CompositeInsertOp op,
44 SmallVectorImpl<spirv::CompositeInsertOp> &insertions);
45};
46
47} // namespace
48
49void RewriteInsertsPass::runOnOperation() {
50 SmallVector<SmallVector<spirv::CompositeInsertOp, 4>, 4> workList;
51 getOperation().walk([this, &workList](spirv::CompositeInsertOp op) {
52 SmallVector<spirv::CompositeInsertOp, 4> insertions;
53 if (succeeded(collectInsertionChain(op, insertions)))
54 workList.push_back(insertions);
55 });
56
57 for (const auto &insertions : workList) {
58 auto lastCompositeInsertOp = insertions.back();
59 auto compositeType = lastCompositeInsertOp.getType();
60 auto location = lastCompositeInsertOp.getLoc();
61
62 SmallVector<Value, 4> operands;
63 // Collect inserted objects.
64 for (auto insertionOp : insertions)
65 operands.push_back(insertionOp.getObject());
66
67 OpBuilder builder(lastCompositeInsertOp);
68 auto compositeConstructOp = spirv::CompositeConstructOp::create(
69 builder, location, compositeType, operands);
70
71 lastCompositeInsertOp.replaceAllUsesWith(
72 compositeConstructOp->getResult(0));
73
74 // Erase ops.
75 for (auto insertOp : llvm::reverse(insertions)) {
76 auto *op = insertOp.getOperation();
77 if (op->use_empty())
78 insertOp.erase();
79 }
80 }
81}
82
83LogicalResult RewriteInsertsPass::collectInsertionChain(
84 spirv::CompositeInsertOp op,
85 SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
86 if (isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
87 return failure();
88
89 auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
90 // TODO: handle nested composite object.
91 if (indicesArrayAttr.size() == 1) {
92 auto numElements = cast<spirv::CompositeType>(op.getComposite().getType())
93 .getNumElements();
94
95 auto index = cast<IntegerAttr>(indicesArrayAttr[0]).getInt();
96 // Need a last index to collect a sequential chain.
97 if (index + 1 != numElements)
98 return failure();
99
100 insertions.resize(numElements);
101 while (true) {
102 insertions[index] = op;
103
104 if (index == 0)
105 return success();
106
107 op = op.getComposite().getDefiningOp<spirv::CompositeInsertOp>();
108 if (!op)
109 return failure();
110
111 --index;
112 indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
113 if ((indicesArrayAttr.size() != 1) ||
114 (cast<IntegerAttr>(indicesArrayAttr[0]).getInt() != index))
115 return failure();
116 }
117 }
118 return failure();
119}
return success()
Include the generated interface declarations.