MLIR  20.0.0git
VectorPattern.cpp
Go to the documentation of this file.
1 //===- VectorPattern.cpp - Vector conversion pattern to the LLVM dialect --===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 
12 using namespace mlir;
13 
14 // For >1-D vector types, extracts the necessary information to iterate over all
15 // 1-D subvectors in the underlying LLVM representation of the n-D vector.
16 // Iterates on the LLVM array type until we hit a non-array type (which is
17 // expected to be an LLVM vector type; otherwise `llvm1DVectorTy` is not set).
20  const LLVMTypeConverter &converter) {
21  assert(vectorType.getRank() > 1 && "expected >1D vector type");
22  NDVectorTypeInfo info;
23  info.llvmNDVectorTy = converter.convertType(vectorType);
25  info.llvmNDVectorTy = nullptr;
26  return info;
27  }
28  info.arraySizes.reserve(vectorType.getRank() - 1);
29  auto llvmTy = info.llvmNDVectorTy;
30  while (isa<LLVM::LLVMArrayType>(llvmTy)) {
31  info.arraySizes.push_back(
32  cast<LLVM::LLVMArrayType>(llvmTy).getNumElements());
33  llvmTy = cast<LLVM::LLVMArrayType>(llvmTy).getElementType();
34  }
35  if (!LLVM::isCompatibleVectorType(llvmTy))
36  return info;
37  info.llvm1DVectorTy = llvmTy;
38  return info;
39 }
40 
41 // Express `linearIndex` in terms of coordinates of `basis`.
42 // Returns the empty vector when linearIndex is out of the range [0, P) where
43 // P is the product of all the basis coordinates.
44 //
45 // Prerequisites:
46 // Basis is an array of nonnegative integers (signed type inherited from
47 // vector shape type).
49  unsigned linearIndex) {
51  res.reserve(basis.size());
52  for (unsigned basisElement : llvm::reverse(basis)) {
53  res.push_back(linearIndex % basisElement);
54  linearIndex = linearIndex / basisElement;
55  }
56  if (linearIndex > 0)
57  return {};
58  std::reverse(res.begin(), res.end());
59  return res;
60 }
61 
62 // Iterate over the linear index, convert it to coords space and invoke `fun`
63 // on each position (e.g. to insert a splatted 1-D vector at that position).
65  OpBuilder &builder,
66  function_ref<void(ArrayRef<int64_t>)> fun) {
67  unsigned ub = 1;
68  for (auto s : info.arraySizes)
69  ub *= s;
70  for (unsigned linearIndex = 0; linearIndex < ub; ++linearIndex) {
71  auto coords = getCoordinates(info.arraySizes, linearIndex);
72  // Linear index is out of bounds, we are done.
73  if (coords.empty())
74  break;
75  assert(coords.size() == info.arraySizes.size());
76  fun(coords);
77  }
78 }
79 
81  Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter,
82  std::function<Value(Type, ValueRange)> createOperand,
83  ConversionPatternRewriter &rewriter) {
84  auto resultNDVectorType = cast<VectorType>(op->getResult(0).getType());
85  auto resultTypeInfo =
86  extractNDVectorTypeInfo(resultNDVectorType, typeConverter);
87  auto result1DVectorTy = resultTypeInfo.llvm1DVectorTy;
88  auto resultNDVectoryTy = resultTypeInfo.llvmNDVectorTy;
89  auto loc = op->getLoc();
90  Value desc = rewriter.create<LLVM::UndefOp>(loc, resultNDVectoryTy);
91  nDVectorIterate(resultTypeInfo, rewriter, [&](ArrayRef<int64_t> position) {
92  // For this unrolled `position` corresponding to the `linearIndex`^th
93  // element, extract operand vectors
94  SmallVector<Value, 4> extractedOperands;
95  for (const auto &operand : llvm::enumerate(operands)) {
96  extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
97  loc, operand.value(), position));
98  }
99  Value newVal = createOperand(result1DVectorTy, extractedOperands);
100  desc = rewriter.create<LLVM::InsertValueOp>(loc, desc, newVal, position);
101  });
102  rewriter.replaceOp(op, desc);
103  return success();
104 }
105 
107  Operation *op, StringRef targetOp, ValueRange operands,
108  ArrayRef<NamedAttribute> targetAttrs,
109  const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter,
110  IntegerOverflowFlags overflowFlags) {
111  assert(!operands.empty());
112 
113  // Cannot convert ops if their operands are not of LLVM type.
114  if (!llvm::all_of(operands.getTypes(), isCompatibleType))
115  return failure();
116 
117  auto llvmNDVectorTy = operands[0].getType();
118  if (!isa<LLVM::LLVMArrayType>(llvmNDVectorTy))
119  return oneToOneRewrite(op, targetOp, operands, targetAttrs, typeConverter,
120  rewriter, overflowFlags);
121 
122  auto callback = [op, targetOp, targetAttrs, overflowFlags,
123  &rewriter](Type llvm1DVectorTy, ValueRange operands) {
124  Operation *newOp =
125  rewriter.create(op->getLoc(), rewriter.getStringAttr(targetOp),
126  operands, llvm1DVectorTy, targetAttrs);
127  LLVM::detail::setNativeProperties(newOp, overflowFlags);
128  return newOp->getResult(0);
129  };
130 
131  return handleMultidimensionalVectors(op, operands, typeConverter, callback,
132  rewriter);
133 }
static int64_t getNumElements(Type t)
Compute the total number of elements in the given type, also taking into account nested types.
static Value handleMultidimensionalVectors(ImplicitLocOpBuilder &builder, ValueRange operands, int64_t vectorWidth, llvm::function_ref< Value(ValueRange)> compute)
StringAttr getStringAttr(const Twine &bytes)
Definition: Builders.cpp:302
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class helps build Operations.
Definition: Builders.h:216
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:497
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
type_range getType() const
type_range getTypes() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
LogicalResult handleMultidimensionalVectors(Operation *op, ValueRange operands, const LLVMTypeConverter &typeConverter, std::function< Value(Type, ValueRange)> createOperand, ConversionPatternRewriter &rewriter)
LogicalResult oneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
Replaces the given operation "op" with a new operation of type "targetOp" and given operands.
Definition: Pattern.cpp:338
LogicalResult vectorOneToOneRewrite(Operation *op, StringRef targetOp, ValueRange operands, ArrayRef< NamedAttribute > targetAttrs, const LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter, IntegerOverflowFlags overflowFlags=IntegerOverflowFlags::none)
void setNativeProperties(Operation *op, IntegerOverflowFlags overflowFlags)
Handle generically setting flags as native properties on LLVM operations.
Definition: Pattern.cpp:330
void nDVectorIterate(const NDVectorTypeInfo &info, OpBuilder &builder, function_ref< void(ArrayRef< int64_t >)> fun)
NDVectorTypeInfo extractNDVectorTypeInfo(VectorType vectorType, const LLVMTypeConverter &converter)
SmallVector< int64_t, 4 > getCoordinates(ArrayRef< int64_t > basis, unsigned linearIndex)
bool isCompatibleVectorType(Type type)
Returns true if the given type is a vector type compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:878
bool isCompatibleType(Type type)
Returns true if the given type is compatible with the LLVM dialect.
Definition: LLVMTypes.cpp:860
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
Include the generated interface declarations.
SmallVector< int64_t, 4 > arraySizes
Definition: VectorPattern.h:27