22#define GEN_PASS_DEF_SPIRVREWRITEINSERTSPASS
23#include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
33class RewriteInsertsPass
36 void runOnOperation()
override;
43 collectInsertionChain(spirv::CompositeInsertOp op,
44 SmallVectorImpl<spirv::CompositeInsertOp> &insertions);
49void RewriteInsertsPass::runOnOperation() {
50 SmallVector<SmallVector<spirv::CompositeInsertOp, 4>, 4> workList;
51 getOperation().walk([
this, &workList](spirv::CompositeInsertOp op) {
52 SmallVector<spirv::CompositeInsertOp, 4> insertions;
53 if (succeeded(collectInsertionChain(op, insertions)))
54 workList.push_back(insertions);
57 for (
const auto &insertions : workList) {
58 auto lastCompositeInsertOp = insertions.back();
59 auto compositeType = lastCompositeInsertOp.getType();
60 auto location = lastCompositeInsertOp.getLoc();
62 SmallVector<Value, 4> operands;
64 for (
auto insertionOp : insertions)
65 operands.push_back(insertionOp.getObject());
67 OpBuilder builder(lastCompositeInsertOp);
68 auto compositeConstructOp = spirv::CompositeConstructOp::create(
69 builder, location, compositeType, operands);
71 lastCompositeInsertOp.replaceAllUsesWith(
72 compositeConstructOp->getResult(0));
75 for (
auto insertOp : llvm::reverse(insertions)) {
76 auto *op = insertOp.getOperation();
83LogicalResult RewriteInsertsPass::collectInsertionChain(
84 spirv::CompositeInsertOp op,
85 SmallVectorImpl<spirv::CompositeInsertOp> &insertions) {
86 if (isa<spirv::CooperativeMatrixType>(op.getComposite().getType()))
89 auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
91 if (indicesArrayAttr.size() == 1) {
92 auto numElements = cast<spirv::CompositeType>(op.getComposite().getType())
95 auto index = cast<IntegerAttr>(indicesArrayAttr[0]).getInt();
97 if (index + 1 != numElements)
100 insertions.resize(numElements);
102 insertions[index] = op;
107 op = op.getComposite().getDefiningOp<spirv::CompositeInsertOp>();
112 indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
113 if ((indicesArrayAttr.size() != 1) ||
114 (cast<IntegerAttr>(indicesArrayAttr[0]).getInt() != index))
Include the generated interface declarations.