MLIR  20.0.0git
ArmNeon2dToIntr.cpp
Go to the documentation of this file.
1 //===- ArmNeon2dToIntr.cpp - convert Arm Neon 2d ops to intrinsics --------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Pass/PassRegistry.h"
17 
18 namespace mlir {
19 #define GEN_PASS_DEF_CONVERTARMNEON2DTOINTR
20 #include "mlir/Conversion/Passes.h.inc"
21 } // namespace mlir
22 
23 using namespace mlir;
24 using namespace mlir::arm_neon;
25 
26 namespace {
27 
28 class Sdot2dLoweringPattern : public OpRewritePattern<Sdot2dOp> {
29 public:
31 
32  /// Convert to 1-dimensional vector type to match the requirements of
33  /// arm.neon.intr.sdot
34  LogicalResult matchAndRewrite(Sdot2dOp op,
35  PatternRewriter &rewriter) const override {
36  Type elemType = cast<VectorType>(op.getB().getType()).getElementType();
37  int length = cast<VectorType>(op.getB().getType()).getShape()[0] *
38  Sdot2dOp::kReductionSize;
39  VectorType flattenedVectorType = VectorType::get({length}, elemType);
40  Value b2d = op.getB();
41  Value c2d = op.getC();
42  Location loc = op.getLoc();
43  Value b1d =
44  rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, b2d);
45  Value c1d =
46  rewriter.create<vector::ShapeCastOp>(loc, flattenedVectorType, c2d);
47  Value newOp = rewriter.create<SdotOp>(loc, op.getRes().getType(), op.getA(),
48  b1d, c1d);
49  rewriter.replaceOp(op, {newOp});
50  return success();
51  }
52 };
53 
54 class ConvertArmNeon2dToIntr
55  : public impl::ConvertArmNeon2dToIntrBase<ConvertArmNeon2dToIntr> {
56  void runOnOperation() override {
57  auto *context = &getContext();
58 
59  RewritePatternSet patterns(context);
61 
62  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
63  return signalPassFailure();
64  }
65 };
66 
67 } // namespace
68 
70  patterns.add<Sdot2dLoweringPattern>(patterns.getContext());
71 }
72 
73 std::unique_ptr<Pass> mlir::createConvertArmNeon2dToIntrPass() {
74  return std::make_unique<ConvertArmNeon2dToIntr>();
75 }
static MLIRContext * getContext(OpFoldResult val)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:66
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:791
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Include the generated interface declarations.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateConvertArmNeon2dToIntrPatterns(RewritePatternSet &patterns)
Populates patterns for the lowering of Arm NEON 2D ops to intrinsics.
std::unique_ptr< Pass > createConvertArmNeon2dToIntrPass()
Creates a pass to lower Arm NEON 2D ops to intrinsics, i.e. to equivalent ops that operate on flattened 1-D vectors.
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
OpRewritePattern(MLIRContext *context, PatternBenefit benefit=1, ArrayRef< StringRef > generatedNames={})
Patterns must specify the root operation name they match against, and can also specify the benefit of...
Definition: PatternMatch.h:362