MLIR 22.0.0git
VectorPattern.h
Go to the documentation of this file.
1//===- VectorPattern.h - Conversion pattern to the LLVM dialect -*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#ifndef MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
10#define MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
11
14
15namespace mlir {
16
17namespace LLVM {
18namespace detail {
19// Helper struct to "unroll" operations on n-D vectors in terms of operations on
20// 1-D LLVM vectors.
22 // LLVM array struct which encodes n-D vectors.
24 // LLVM vector type which encodes the inner 1-D vector type.
26 // Multiplicity of llvmNDVectorTy to llvm1DVectorTy.
28};
29
// For >1-D vector types, extracts the necessary information to iterate over all
// 1-D subvectors in the underlying LLVM representation of the n-D vector.
// Iterates on the LLVM array type until it hits a non-array type (which is
// asserted to be an LLVM vector type).
34NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType,
35 const LLVMTypeConverter &converter);
36
37// Express `linearIndex` in terms of coordinates of `basis`.
// Returns the empty vector when linearIndex is outside the range [0, P), where
// P is the product of all the basis coordinates.
40//
41// Prerequisites:
42// Basis is an array of nonnegative integers (signed type inherited from
43// vector shape type).
45 unsigned linearIndex);
46
// Iterates over the linear indices, converts each one to coordinate space, and
// invokes the callback at each position (e.g. to insert a splatted 1-D vector).
49void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder,
51
53 Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
54 std::function<Value(Type, ValueRange)> createOperand,
55 ConversionPatternRewriter &rewriter);
56
57LogicalResult vectorOneToOneRewrite(
58 Operation *op, StringRef targetOp, ValueRange operands,
59 ArrayRef<NamedAttribute> targetAttrs,
60 const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
61 IntegerOverflowFlags overflowFlags = IntegerOverflowFlags::none);
62} // namespace detail
63} // namespace LLVM
64
65// Default attribute conversion class, which passes all source attributes
66// through to the target op, unmodified.
67template <typename SourceOp, typename TargetOp>
69public:
70 AttrConvertPassThrough(SourceOp srcOp) : srcAttrs(srcOp->getAttrs()) {}
71
72 ArrayRef<NamedAttribute> getAttrs() const { return srcAttrs; }
73 LLVM::IntegerOverflowFlags getOverflowFlags() const {
74 return LLVM::IntegerOverflowFlags::none;
75 }
76
77private:
79};
80
81/// Basic lowering implementation to rewrite Ops with just one result to the
82/// LLVM Dialect. This supports higher-dimensional vector types.
/// The AttrConvert template template parameter should be a template class
/// with SourceOp and TargetOp type parameters, a constructor that takes
/// a SourceOp instance, a getAttrs() method that returns
/// ArrayRef<NamedAttribute>, and a getOverflowFlags() method that returns
/// LLVM::IntegerOverflowFlags.
87template <typename SourceOp, typename TargetOp,
88 template <typename, typename> typename AttrConvert =
89 AttrConvertPassThrough>
91public:
94
95 /// Return the given type if it's a floating point type. If the given type is
96 /// a vector type, return its element type if it's a floating point type.
97 static FloatType getFloatingPointType(Type type) {
98 if (auto floatType = dyn_cast<FloatType>(type))
99 return floatType;
100 if (auto vecType = dyn_cast<VectorType>(type))
101 return dyn_cast<FloatType>(vecType.getElementType());
102 return nullptr;
103 }
104
105 LogicalResult
106 matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
107 ConversionPatternRewriter &rewriter) const override {
108 static_assert(
109 std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
110 "expected single result op");
111
112 // The pattern should not apply if a floating-point operand is converted to
113 // a non-floating-point type. This indicates that the floating point type
114 // is not supported by the LLVM lowering. (Such types are converted to
115 // integers.)
116 auto checkType = [&](Value v) -> LogicalResult {
117 FloatType floatType = getFloatingPointType(v.getType());
118 if (!floatType)
119 return success();
120 Type convertedType = this->getTypeConverter()->convertType(floatType);
121 if (!isa_and_nonnull<FloatType>(convertedType))
122 return rewriter.notifyMatchFailure(op,
123 "unsupported floating point type");
124 return success();
125 };
126 for (Value operand : op->getOperands())
127 if (failed(checkType(operand)))
128 return failure();
129 if (failed(checkType(op->getResult(0))))
130 return failure();
131
132 // Determine attributes for the target op
133 AttrConvert<SourceOp, TargetOp> attrConvert(op);
134
136 op, TargetOp::getOperationName(), adaptor.getOperands(),
137 attrConvert.getAttrs(), *this->getTypeConverter(), rewriter,
138 attrConvert.getOverflowFlags());
139 }
140};
141} // namespace mlir
142
143#endif // MLIR_CONVERSION_LLVMCOMMON_VECTORPATTERN_H
return success()
AttrConvertPassThrough(SourceOp srcOp)
LLVM::IntegerOverflowFlags getOverflowFlags() const
ArrayRef< NamedAttribute > getAttrs() const
ConvertOpToLLVMPattern(const LLVMTypeConverter &typeConverter, PatternBenefit benefit=1)
Definition Pattern.h:215
const LLVMTypeConverter * getTypeConverter() const
Definition Pattern.cpp:27
Conversion from types to the LLVM IR dialect.
This class helps build Operations.
Definition Builders.h:207
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Basic lowering implementation to rewrite Ops with just one result to the LLVM Dialect.
VectorConvertToLLVMPattern< SourceOp, TargetOp > Super
static FloatType getFloatingPointType(Type type)
Return the given type if it's a floating point type.
LogicalResult matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override
Methods that operate on the SourceOp type.
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
Include the generated interface declarations.
llvm::function_ref< Fn > function_ref
Definition LLVM.h:152
SmallVector< int64_t, 4 > arraySizes