MLIR  21.0.0git
X86VectorDialect.cpp
Go to the documentation of this file.
1 //===- X86VectorDialect.cpp - MLIR X86Vector ops implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the X86Vector dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/TypeUtilities.h"
20 
21 using namespace mlir;
22 
23 #include "mlir/Dialect/X86Vector/X86VectorInterfaces.cpp.inc"
24 
25 #include "mlir/Dialect/X86Vector/X86VectorDialect.cpp.inc"
26 
27 void x86vector::X86VectorDialect::initialize() {
28  addOperations<
29 #define GET_OP_LIST
30 #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
31  >();
32 }
33 
34 static SmallVector<Value>
36  RewriterBase &rewriter,
37  const LLVMTypeConverter &typeConverter) {
38  SmallVector<Value> operands;
39  auto opType = memrefVal.getType();
40 
41  Type llvmStructType = typeConverter.convertType(opType);
42  Value llvmStruct =
43  rewriter
44  .create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
45  .getResult(0);
46  MemRefDescriptor memRefDescriptor(llvmStruct);
47 
48  Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
49  operands.push_back(ptr);
50 
51  return operands;
52 }
53 
54 LogicalResult x86vector::MaskCompressOp::verify() {
55  if (getSrc() && getConstantSrc())
56  return emitError("cannot use both src and constant_src");
57 
58  if (getSrc() && (getSrc().getType() != getDst().getType()))
59  return emitError("failed to verify that src and dst have same type");
60 
61  if (getConstantSrc() && (getConstantSrc()->getType() != getDst().getType()))
62  return emitError(
63  "failed to verify that constant_src and dst have same type");
64 
65  return success();
66 }
67 
68 SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
69  RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
70  auto loc = getLoc();
71 
72  auto opType = getA().getType();
73  Value src;
74  if (getSrc()) {
75  src = getSrc();
76  } else if (getConstantSrc()) {
77  src = rewriter.create<LLVM::ConstantOp>(loc, opType, getConstantSrcAttr());
78  } else {
79  auto zeroAttr = rewriter.getZeroAttr(opType);
80  src = rewriter.create<LLVM::ConstantOp>(loc, opType, zeroAttr);
81  }
82 
83  return SmallVector<Value>{getA(), src, getK()};
84 }
85 
87 x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
88  const LLVMTypeConverter &typeConverter) {
89  SmallVector<Value> operands(getOperands());
90  // Dot product of all elements, broadcasted to all elements.
91  Value scale =
92  rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff);
93  operands.push_back(scale);
94 
95  return operands;
96 }
97 
98 SmallVector<Value> x86vector::BcstBF16ToPackedF32Op::getIntrinsicOperands(
99  RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
100  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
101 }
102 
104 x86vector::CvtPackedOddIndexedBF16ToF32Op::getIntrinsicOperands(
105  RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
106  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
107 }
108 
110 x86vector::CvtPackedEvenIndexedBF16ToF32Op::getIntrinsicOperands(
111  RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
112  return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
113 }
114 
115 #define GET_OP_CLASSES
116 #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
static SmallVector< Value > getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType > memrefVal, RewriterBase &rewriter, const LLVMTypeConverter &typeConverter)
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:320
IntegerType getI8Type()
Definition: Builders.cpp:59
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
LogicalResult convertType(Type t, SmallVectorImpl< Type > &results) const
Convert the given type.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:66
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Definition: MemRefBuilder.h:33
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref.alignedPtr + memref.offset * sizeof(type.getElementType()).
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:453
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to track mutations and create new operations.
Definition: PatternMatch.h:358
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition: Value.h:474
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:424