23 #define GEN_PASS_DEF_SPIRVREWRITEINSERTSPASS
24 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
34 class RewriteInsertsPass
35 :
public spirv::impl::SPIRVRewriteInsertsPassBase<RewriteInsertsPass> {
37 void runOnOperation()
override;
44 collectInsertionChain(spirv::CompositeInsertOp op,
50 void RewriteInsertsPass::runOnOperation() {
52 getOperation().walk([
this, &workList](spirv::CompositeInsertOp op) {
54 if (succeeded(collectInsertionChain(op, insertions)))
55 workList.push_back(insertions);
58 for (
const auto &insertions : workList) {
59 auto lastCompositeInsertOp = insertions.back();
60 auto compositeType = lastCompositeInsertOp.getType();
61 auto location = lastCompositeInsertOp.getLoc();
65 for (
auto insertionOp : insertions)
66 operands.push_back(insertionOp.getObject());
69 auto compositeConstructOp = builder.create<spirv::CompositeConstructOp>(
70 location, compositeType, operands);
72 lastCompositeInsertOp.replaceAllUsesWith(
73 compositeConstructOp->getResult(0));
76 for (
auto insertOp : llvm::reverse(insertions)) {
77 auto *op = insertOp.getOperation();
84 LogicalResult RewriteInsertsPass::collectInsertionChain(
85 spirv::CompositeInsertOp op,
87 auto indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
89 if (indicesArrayAttr.size() == 1) {
90 auto numElements = cast<spirv::CompositeType>(op.getComposite().getType())
93 auto index = cast<IntegerAttr>(indicesArrayAttr[0]).getInt();
95 if (index + 1 != numElements)
98 insertions.resize(numElements);
100 insertions[index] = op;
105 op = op.getComposite().getDefiningOp<spirv::CompositeInsertOp>();
110 indicesArrayAttr = cast<ArrayAttr>(op.getIndices());
111 if ((indicesArrayAttr.size() != 1) ||
112 (cast<IntegerAttr>(indicesArrayAttr[0]).getInt() != index))
This class helps build Operations.
Include the generated interface declarations.