MLIR  16.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/IR/BuiltinOps.h"
17 #include "mlir/IR/PatternMatch.h"
18 
19 using namespace mlir;
20 using namespace mlir::x86vector;
21 
22 /// Extracts the "main" vector element type from the given X86Vector operation.
23 template <typename OpTy>
24 static Type getSrcVectorElementType(OpTy op) {
25  return op.getSrc().getType().template cast<VectorType>().getElementType();
26 }
27 template <>
28 Type getSrcVectorElementType(Vp2IntersectOp op) {
29  return op.getA().getType().template cast<VectorType>().getElementType();
30 }
31 
32 namespace {
33 
34 /// Base conversion for AVX512 ops that can be lowered to one of the two
35 /// intrinsics based on the bitwidth of their "main" vector element type. This
36 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the
37 /// results of multi-result intrinsic ops.
38 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
39 struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
40  explicit LowerToIntrinsic(LLVMTypeConverter &converter)
41  : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
42 
43  LLVMTypeConverter &getTypeConverter() const {
44  return *static_cast<LLVMTypeConverter *>(
46  }
47 
49  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
50  ConversionPatternRewriter &rewriter) const override {
51  Type elementType = getSrcVectorElementType<OpTy>(op);
52  unsigned bitwidth = elementType.getIntOrFloatBitWidth();
53  if (bitwidth == 32)
54  return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
55  adaptor.getOperands(),
56  getTypeConverter(), rewriter);
57  if (bitwidth == 64)
58  return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
59  adaptor.getOperands(),
60  getTypeConverter(), rewriter);
61  return rewriter.notifyMatchFailure(
62  op, "expected 'src' to be either f32 or f64");
63  }
64 };
65 
66 struct MaskCompressOpConversion
67  : public ConvertOpToLLVMPattern<MaskCompressOp> {
69 
71  matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
72  ConversionPatternRewriter &rewriter) const override {
73  auto opType = adaptor.getA().getType();
74 
75  Value src;
76  if (op.getSrc()) {
77  src = adaptor.getSrc();
78  } else if (op.getConstantSrc()) {
79  src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
80  op.getConstantSrcAttr());
81  } else {
82  Attribute zeroAttr = rewriter.getZeroAttr(opType);
83  src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
84  }
85 
86  rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.getA(),
87  src, adaptor.getK());
88 
89  return success();
90  }
91 };
92 
93 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
95 
97  matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
98  ConversionPatternRewriter &rewriter) const override {
99  auto opType = adaptor.getA().getType();
100  rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.getA());
101  return success();
102  }
103 };
104 
105 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
107 
109  matchAndRewrite(DotOp op, OpAdaptor adaptor,
110  ConversionPatternRewriter &rewriter) const override {
111  auto opType = adaptor.getA().getType();
112  Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
113  // Dot product of all elements, broadcasted to all elements.
114  auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
115  Value scale =
116  rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
117  rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.getA(),
118  adaptor.getB(), scale);
119  return success();
120  }
121 };
122 
/// Compile-time record tying a "main" AVX512 op to the two intrinsic ops it
/// lowers to: one for vectors of 32-bit elements and one for vectors of
/// 64-bit elements. The three types are exposed as member aliases.
template <typename MainOpTy, typename Intr32Ty, typename Intr64Ty>
struct RegEntry {
  using MainOp = MainOpTy;
  using Intr32Op = Intr32Ty;
  using Intr64Op = Intr64Ty;
};
131 
132 /// A container for op association entries facilitating the configuration of
133 /// dialect conversion.
134 template <typename... Args>
135 struct RegistryImpl {
136  /// Registers the patterns specializing the "main" op to one of the
137  /// "intrinsic" ops depending on elemental type.
138  static void registerPatterns(LLVMTypeConverter &converter,
139  RewritePatternSet &patterns) {
140  patterns
141  .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
142  typename Args::Intr64Op>...>(converter);
143  }
144 
145  /// Configures the conversion target to lower out "main" ops.
146  static void configureTarget(LLVMConversionTarget &target) {
147  target.addIllegalOp<typename Args::MainOp...>();
148  target.addLegalOp<typename Args::Intr32Op...>();
149  target.addLegalOp<typename Args::Intr64Op...>();
150  }
151 };
152 
/// The concrete registry used by this file. Each entry lists a "main" op
/// followed by the intrinsic op used for 32-bit elements and the intrinsic op
/// used for 64-bit elements (the order of RegEntry's template parameters).
153 using Registry = RegistryImpl<
154  RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
155  RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
156  RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
157 
158 } // namespace
159 
160 /// Populate the given list with patterns that convert from X86Vector to LLVM.
162  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
163  Registry::registerPatterns(converter, patterns);
164  patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
165  converter);
166 }
167 
169  LLVMConversionTarget &target) {
170  Registry::configureTarget(target);
171  target.addLegalOp<MaskCompressIntrOp>();
172  target.addIllegalOp<MaskCompressOp>();
173  target.addLegalOp<RsqrtIntrOp>();
174  target.addIllegalOp<RsqrtOp>();
175  target.addLegalOp<DotIntrOp>();
176  target.addIllegalOp<DotOp>();
177 }
Include the generated interface declarations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:133
void addLegalOp(OperationName op)
Register the given operations as legal.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:288
TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
static Type getSrcVectorElementType(OpTy op)
Extracts the "main" vector element type from the given X86Vector operation.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:685
Derived class that automatically populates legalization information for different LLVM ops...
void addIllegalOp(OperationName op)
Register the given operation as illegal, i.e. this operation is known to not be supported by this target.
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering X86Vector ops to ops that map to LLVM intrinsics.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:404
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
IntegerAttr getI8IntegerAttr(int8_t value)
Definition: Builders.cpp:190
Attributes are known-constant values of operations.
Definition: Attributes.h:24
void populateX86VectorLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower X86Vector ops to ops that map to LLVM intrinsics.
LogicalResult notifyMatchFailure(Location loc, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:85
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:451
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:30
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands...
Definition: Pattern.cpp:309
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:132