MLIR  22.0.0git
ConvertToReplicatedConstantCompositePass.cpp
Go to the documentation of this file.
1 //===- ConvertToReplicatedConstantCompositePass.cpp -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements a pass to convert a splat composite spirv.Constant and
10 // spirv.SpecConstantComposite to spirv.EXT.ConstantCompositeReplicate and
11 // spirv.EXT.SpecConstantCompositeReplicate respectively.
12 //
13 //===----------------------------------------------------------------------===//
14 
18 
19 namespace mlir::spirv {
20 #define GEN_PASS_DEF_SPIRVREPLICATEDCONSTANTCOMPOSITEPASS
21 #include "mlir/Dialect/SPIRV/Transforms/Passes.h.inc"
22 
23 namespace {
24 
25 static Type getArrayElemType(Attribute attr) {
26  if (auto typedAttr = dyn_cast<TypedAttr>(attr)) {
27  return typedAttr.getType();
28  }
29 
30  if (auto arrayAttr = dyn_cast<ArrayAttr>(attr)) {
31  return ArrayType::get(getArrayElemType(arrayAttr[0]), arrayAttr.size());
32  }
33 
34  return nullptr;
35 }
36 
37 static std::pair<Attribute, uint32_t>
38 getSplatAttrAndNumElements(Attribute valueAttr, Type valueType) {
39  auto compositeType = dyn_cast_or_null<spirv::CompositeType>(valueType);
40  if (!compositeType)
41  return {nullptr, 1};
42 
43  if (auto splatAttr = dyn_cast<SplatElementsAttr>(valueAttr)) {
44  return {splatAttr.getSplatValue<Attribute>(), splatAttr.size()};
45  }
46 
47  if (auto arrayAttr = dyn_cast<ArrayAttr>(valueAttr)) {
48  if (llvm::all_equal(arrayAttr)) {
49  Attribute attr = arrayAttr[0];
50  uint32_t numElements = arrayAttr.size();
51 
52  // Find the inner-most splat value for array of composites
53  auto [newAttr, newNumElements] =
54  getSplatAttrAndNumElements(attr, getArrayElemType(attr));
55  if (newAttr) {
56  attr = newAttr;
57  numElements *= newNumElements;
58  }
59  return {attr, numElements};
60  }
61  }
62 
63  return {nullptr, 1};
64 }
65 
66 struct ConstantOpConversion final : OpRewritePattern<spirv::ConstantOp> {
68 
69  LogicalResult matchAndRewrite(spirv::ConstantOp op,
70  PatternRewriter &rewriter) const override {
71  auto [attr, numElements] =
72  getSplatAttrAndNumElements(op.getValue(), op.getType());
73  if (!attr)
74  return rewriter.notifyMatchFailure(op, "composite is not splat");
75 
76  if (numElements == 1)
77  return rewriter.notifyMatchFailure(op,
78  "composite has only one constituent");
79 
80  rewriter.replaceOpWithNewOp<spirv::EXTConstantCompositeReplicateOp>(
81  op, op.getType(), attr);
82  return success();
83  }
84 };
85 
86 struct SpecConstantCompositeOpConversion final
87  : OpRewritePattern<spirv::SpecConstantCompositeOp> {
89 
90  LogicalResult matchAndRewrite(spirv::SpecConstantCompositeOp op,
91  PatternRewriter &rewriter) const override {
92  auto compositeType = dyn_cast_or_null<spirv::CompositeType>(op.getType());
93  if (!compositeType)
94  return rewriter.notifyMatchFailure(op, "not a composite constant");
95 
96  ArrayAttr constituents = op.getConstituents();
97  if (constituents.size() == 1)
98  return rewriter.notifyMatchFailure(op,
99  "composite has only one consituent");
100 
101  if (!llvm::all_equal(constituents))
102  return rewriter.notifyMatchFailure(op, "composite is not splat");
103 
104  auto splatConstituent = dyn_cast<FlatSymbolRefAttr>(constituents[0]);
105  if (!splatConstituent)
106  return rewriter.notifyMatchFailure(
107  op, "expected flat symbol reference for splat constituent");
108 
109  rewriter.replaceOpWithNewOp<spirv::EXTSpecConstantCompositeReplicateOp>(
110  op, TypeAttr::get(op.getType()), op.getSymNameAttr(), splatConstituent);
111 
112  return success();
113  }
114 };
115 
116 struct ConvertToReplicatedConstantCompositePass final
117  : spirv::impl::SPIRVReplicatedConstantCompositePassBase<
118  ConvertToReplicatedConstantCompositePass> {
119  void runOnOperation() override {
120  MLIRContext *context = &getContext();
121  RewritePatternSet patterns(context);
122  patterns.add<ConstantOpConversion, SpecConstantCompositeOpConversion>(
123  context);
124  walkAndApplyPatterns(getOperation(), std::move(patterns));
125  }
126 };
127 
128 } // namespace
129 } // namespace mlir::spirv
static MLIRContext * getContext(OpFoldResult val)
static ArrayType get(Type elementType, unsigned elementCount)
Definition: SPIRVTypes.cpp:50
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, RewriterBase::Listener *listener=nullptr)
A fast walk-based pattern rewrite driver.
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:319