MLIR  14.0.0git
LegalizeForLLVMExport.cpp
Go to the documentation of this file.
1 //===- LegalizeForLLVMExport.cpp - Prepare X86Vector for LLVM translation -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/IR/PatternMatch.h"
19 
20 using namespace mlir;
21 using namespace mlir::x86vector;
22 
23 /// Extracts the "main" vector element type from the given X86Vector operation.
24 template <typename OpTy>
25 static Type getSrcVectorElementType(OpTy op) {
26  return op.src().getType().template cast<VectorType>().getElementType();
27 }
28 template <>
29 Type getSrcVectorElementType(Vp2IntersectOp op) {
30  return op.a().getType().template cast<VectorType>().getElementType();
31 }
32 
33 namespace {
34 
35 /// Base conversion for AVX512 ops that can be lowered to one of the two
36 /// intrinsics based on the bitwidth of their "main" vector element type. This
37 /// relies on the to-LLVM-dialect conversion helpers to correctly pack the
38 /// results of multi-result intrinsic ops.
39 template <typename OpTy, typename Intr32OpTy, typename Intr64OpTy>
40 struct LowerToIntrinsic : public OpConversionPattern<OpTy> {
41  explicit LowerToIntrinsic(LLVMTypeConverter &converter)
42  : OpConversionPattern<OpTy>(converter, &converter.getContext()) {}
43 
44  LLVMTypeConverter &getTypeConverter() const {
45  return *static_cast<LLVMTypeConverter *>(
47  }
48 
50  matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor,
51  ConversionPatternRewriter &rewriter) const override {
52  Type elementType = getSrcVectorElementType<OpTy>(op);
53  unsigned bitwidth = elementType.getIntOrFloatBitWidth();
54  if (bitwidth == 32)
55  return LLVM::detail::oneToOneRewrite(op, Intr32OpTy::getOperationName(),
56  adaptor.getOperands(),
57  getTypeConverter(), rewriter);
58  if (bitwidth == 64)
59  return LLVM::detail::oneToOneRewrite(op, Intr64OpTy::getOperationName(),
60  adaptor.getOperands(),
61  getTypeConverter(), rewriter);
62  return rewriter.notifyMatchFailure(
63  op, "expected 'src' to be either f32 or f64");
64  }
65 };
66 
67 struct MaskCompressOpConversion
68  : public ConvertOpToLLVMPattern<MaskCompressOp> {
70 
72  matchAndRewrite(MaskCompressOp op, OpAdaptor adaptor,
73  ConversionPatternRewriter &rewriter) const override {
74  auto opType = adaptor.a().getType();
75 
76  Value src;
77  if (op.src()) {
78  src = adaptor.src();
79  } else if (op.constant_src()) {
80  src = rewriter.create<arith::ConstantOp>(op.getLoc(), opType,
81  op.constant_srcAttr());
82  } else {
83  Attribute zeroAttr = rewriter.getZeroAttr(opType);
84  src = rewriter.create<arith::ConstantOp>(op->getLoc(), opType, zeroAttr);
85  }
86 
87  rewriter.replaceOpWithNewOp<MaskCompressIntrOp>(op, opType, adaptor.a(),
88  src, adaptor.k());
89 
90  return success();
91  }
92 };
93 
94 struct RsqrtOpConversion : public ConvertOpToLLVMPattern<RsqrtOp> {
96 
98  matchAndRewrite(RsqrtOp op, OpAdaptor adaptor,
99  ConversionPatternRewriter &rewriter) const override {
100  auto opType = adaptor.a().getType();
101  rewriter.replaceOpWithNewOp<RsqrtIntrOp>(op, opType, adaptor.a());
102  return success();
103  }
104 };
105 
106 struct DotOpConversion : public ConvertOpToLLVMPattern<DotOp> {
108 
110  matchAndRewrite(DotOp op, OpAdaptor adaptor,
111  ConversionPatternRewriter &rewriter) const override {
112  auto opType = adaptor.a().getType();
113  Type llvmIntType = IntegerType::get(&getTypeConverter()->getContext(), 8);
114  // Dot product of all elements, broadcasted to all elements.
115  auto attr = rewriter.getI8IntegerAttr(static_cast<int8_t>(0xff));
116  Value scale =
117  rewriter.create<LLVM::ConstantOp>(op.getLoc(), llvmIntType, attr);
118  rewriter.replaceOpWithNewOp<DotIntrOp>(op, opType, adaptor.a(), adaptor.b(),
119  scale);
120  return success();
121  }
122 };
123 
/// Type-level record pairing a "main" AVX512 op with the two ops it may be
/// rewritten to: one for vectors of 32-bit elements, one for 64-bit elements.
template <typename MainOpT, typename Intr32OpT, typename Intr64OpT>
struct RegEntry {
  /// The op to be lowered.
  using MainOp = MainOpT;
  /// Replacement op for vectors of 32-bit elements.
  using Intr32Op = Intr32OpT;
  /// Replacement op for vectors of 64-bit elements.
  using Intr64Op = Intr64OpT;
};
132 
133 /// A container for op association entries facilitating the configuration of
134 /// dialect conversion.
135 template <typename... Args>
136 struct RegistryImpl {
137  /// Registers the patterns specializing the "main" op to one of the
138  /// "intrinsic" ops depending on elemental type.
139  static void registerPatterns(LLVMTypeConverter &converter,
140  RewritePatternSet &patterns) {
141  patterns
142  .add<LowerToIntrinsic<typename Args::MainOp, typename Args::Intr32Op,
143  typename Args::Intr64Op>...>(converter);
144  }
145 
146  /// Configures the conversion target to lower out "main" ops.
147  static void configureTarget(LLVMConversionTarget &target) {
148  target.addIllegalOp<typename Args::MainOp...>();
149  target.addLegalOp<typename Args::Intr32Op...>();
150  target.addLegalOp<typename Args::Intr64Op...>();
151  }
152 };
153 
154 using Registry = RegistryImpl<
155  RegEntry<MaskRndScaleOp, MaskRndScalePSIntrOp, MaskRndScalePDIntrOp>,
156  RegEntry<MaskScaleFOp, MaskScaleFPSIntrOp, MaskScaleFPDIntrOp>,
157  RegEntry<Vp2IntersectOp, Vp2IntersectDIntrOp, Vp2IntersectQIntrOp>>;
158 
159 } // namespace
160 
161 /// Populate the given list with patterns that convert from X86Vector to LLVM.
163  LLVMTypeConverter &converter, RewritePatternSet &patterns) {
164  Registry::registerPatterns(converter, patterns);
165  patterns.add<MaskCompressOpConversion, RsqrtOpConversion, DotOpConversion>(
166  converter);
167 }
168 
170  LLVMConversionTarget &target) {
171  Registry::configureTarget(target);
172  target.addLegalOp<MaskCompressIntrOp>();
173  target.addIllegalOp<MaskCompressOp>();
174  target.addLegalOp<RsqrtIntrOp>();
175  target.addIllegalOp<RsqrtOp>();
176  target.addLegalOp<DotIntrOp>();
177  target.addIllegalOp<DotOp>();
178 }
Include the generated interface declarations.
Utility class for operation conversions targeting the LLVM dialect that match exactly one source oper...
Definition: Pattern.h:132
OpTy create(Location location, Args &&...args)
Create an operation of specific op type at the current insertion point.
Definition: Builders.h:430
Attribute getZeroAttr(Type type)
Definition: Builders.cpp:264
TypeConverter * getTypeConverter() const
Return the type converter held by this pattern, or nullptr if the pattern does not require type conve...
static Type getSrcVectorElementType(OpTy op)
Extracts the "main" vector element type from the given X86Vector operation.
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity...
Definition: SPIRVOps.cpp:639
Derived class that automatically populates legalization information for different LLVM ops...
LogicalResult notifyMatchFailure(Operation *op, function_ref< void(Diagnostic &)> reasonCallback) override
PatternRewriter hook for notifying match failure reasons.
void configureX86VectorLegalizeForExportTarget(LLVMConversionTarget &target)
Configure the target to support lowering X86Vector ops to ops that map to LLVM intrinsics.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
IntegerAttr getI8IntegerAttr(int8_t value)
Definition: Builders.cpp:166
Attributes are known-constant values of operations.
Definition: Attributes.h:24
void populateX86VectorLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter, RewritePatternSet &patterns)
Collect a set of patterns to lower X86Vector ops to ops that map to LLVM intrinsics.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:72
void addIllegalOp()
Register the given operation as illegal, i.e. one that is not permitted to exist after conversion.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:84
OpTy replaceOpWithNewOp(Operation *op, Args &&... args)
Replaces the result op with a new op that is created without verification.
Definition: PatternMatch.h:741
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&... args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments...
Definition: PatternMatch.h:930
Conversion from types in the Standard dialect to the LLVM IR dialect.
Definition: TypeConverter.h:30
void addLegalOp()
Register the given operations as legal.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition: Types.cpp:91
This class implements a pattern rewriter for use with ConversionPatterns.
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands...
Definition: Pattern.cpp:309
MLIRContext * getContext() const
Return the MLIRContext used to create this pattern.
Definition: PatternMatch.h:132