MLIR  20.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 using namespace mlir::x86vector;
21 
22 /// Extracts the "main" vector element type from the given X86Vector operation.
23 template <typename OpTy>
24 static Type getSrcVectorElementType(OpTy op) {
25  return cast<VectorType>(op.getSrc().getType()).getElementType();
26 }
27 template <>
28 Type getSrcVectorElementType(Vp2IntersectOp op) {
29  return cast<VectorType>(op.getA().getType()).getElementType();
30 }
31 
32 namespace {
33 
34 /// Base conversion for AVX512 ops that can be lowered to one of the two
35 /// intrinsics based on the bitwidth of their "main" vector element type. This
36 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the
37 /// results of multi-result intrinsic ops.
38 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
39 struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
40  explicit LowerToIntrinsic(const LLVMTypeConverter &converter)
41  : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
42 
43  const LLVMTypeConverter &getTypeConverter() const {
44  return *static_cast<const LLVMTypeConverter *>(
46  }
47 
48  LogicalResult
49  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
50  ConversionPatternRewriter &rewriter) const override {
51  Type elementType = getSrcVectorElementType<OpTy>(op);
52  unsigned bitwidth = elementType.getIntOrFloatBitWidth();
53  if (bitwidth == 32)
55  op, Intr32OpTy::getOperationName(), adaptor.getOperands(),
56  op->getAttrs(), getTypeConverter(), rewriter);
57  if (bitwidth == 64)
59  op, Intr64OpTy::getOperationName(), adaptor.getOperands(),
60  op->getAttrs(), getTypeConverter(), rewriter);
61  return rewriter.notifyMatchFailure(
62  op, "expected 'src' to be either f32 or f64");
63  }
64 };
65 
66 struct MaskCompressOpConversion
67  : public ConvertOpToLLVMPattern<MaskCompressOp> {
69 
70  LogicalResult
71  matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
72  ConversionPatternRewriter &rewriter) const override {
73  auto opType = adaptor.getA().getType();
74 
75  Value src;
76  if (op.getSrc()) {
77  src = adaptor.getSrc();
78  } else if (op.getConstantSrc()) {
79  src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
80  op.getConstantSrcAttr());
81  } else {
82  auto zeroAttr = rewriter.getZeroAttr(opType);
83  src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
84  }
85 
86  rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.getA(),
87  src, adaptor.getK());
88 
89  return success();
90  }
91 };
92 
93 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
95 
96  LogicalResult
97  matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
98  ConversionPatternRewriter &rewriter) const override {
99  auto opType = adaptor.getA().getType();
100  rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.getA());
101  return success();
102  }
103 };
104 
105 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
107 
108  LogicalResult
109  matchAndRewrite(DotOp op, OpAdaptor adaptor,
110  ConversionPatternRewriter &rewriter) const override {
111  auto opType = adaptor.getA().getType();
112  Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
113  // Dot product of all elements, broadcasted to all elements.
114  auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
115  Value scale =
116  rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
117  rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.getA(),
118  adaptor.getB(), scale);
119  return success();
120  }
121 };
122 
123 /// An entry associating the "main" AVX512 op with its instantiations for
124 /// vectors of 32-bit and 64-bit elements.
125 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
126 struct RegEntry {
127  using MainOp = OpTy;
128  using Intr32Op = Intr32OpTy;
129  using Intr64Op = Intr64OpTy;
130 };
131 
132 /// A container for op association entries facilitating the configuration of
133 /// dialect conversion.
134 template <typename... Args>
135 struct RegistryImpl {
136  /// Registers the patterns specializing the "main" op to one of the
137  /// "intrinsic" ops depending on elemental type.
138  static void registerPatterns(const LLVMTypeConverter &converter,
139  RewritePatternSet &patterns) {
140  patterns
141  .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
142  typename Args::Intr64Op>...>(converter);
143  }
144 
145  /// Configures the conversion target to lower out "main" ops.
146  static void configureTarget(LLVMConversionTarget &target) {
147  target.addIllegalOp<typename Args::MainOp...>();
148  target.addLegalOp<typename Args::Intr32Op...>();
149  target.addLegalOp<typename Args::Intr64Op...>();
150  }
151 };
152 
153 using Registry = RegistryImpl<
154  RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
155  RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
156  RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
157 
158 } // namespace
159 
160 /// Populate the given list with patterns that convert from X86Vector to LLVM.
162  const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
163  Registry::registerPatterns(converter, patterns);
164  patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
165  converter);
166 }
167 
169  LLVMConversionTarget &target) {
170  Registry::configureTarget(target);
171  target.addLegalOp<MaskCompressIntrOp>();
172  target.addIllegalOp<MaskCompressOp>();
173  target.addLegalOp<RsqrtIntrOp>();
174  target.addIllegalOp<RsqrtOp>();
175  target.addLegalOp<DotIntrOp>();
176  target.addIllegalOp<DotOp>();
177 }
static MLIRContext * getContext(OpFoldResult val)
static Type getSrcVectorElementType(OpTy op)
Extracts the "main" vector element type from the given X86Vector operation.
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:364
IntegerAttr getI8IntegerAttr(int8_t value)
Definition: Builders.cpp:261
This class implements a pattern rewriter for use with ConversionPatterns.
const TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conversion.
void addLegalOp(OperationName op)
Register the given operations as legal.
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e. this operation is known to not be supported by this target.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source operation.
Definition: Pattern.h:143
Derived class that automatically populates legalization information for different LLVM ops.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:853
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure, and provide a callback to populate a diagnostic with the reason why the failure occurred.
Definition: PatternMatch.h:724
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:542
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:133
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:338
Include the generated interface declarations.
void populateX86VectorLegalizeForLLVMExportPatterns(const LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower X86Vector ops to ops that map to LLVM intrinsics.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering X86Vector ops to ops that map to LLVM intrinsics.