MLIR  21.0.0git
X86VectorDialect.cpp
Go to the documentation of this file.
1 //===- X86VectorDialect.cpp - MLIR X86Vector ops implementation -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the X86Vector dialect and its operations.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/TypeUtilities.h"
20 
21 using namespace mlir;
22 
23 #include "mlir/Dialect/X86Vector/X86VectorInterfaces.cpp.inc"
24 
25 #include "mlir/Dialect/X86Vector/X86VectorDialect.cpp.inc"
26 
27 void x86vector::X86VectorDialect::initialize() {
28  addOperations<
29 #define GET_OP_LIST
30 #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
31  >();
32 }
33 
34 static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
35  const LLVMTypeConverter &typeConverter,
36  RewriterBase &rewriter) {
37  MemRefDescriptor memRefDescriptor(buffer);
38  return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
39 }
40 
41 LogicalResult x86vector::MaskCompressOp::verify() {
42  if (getSrc() && getConstantSrc())
43  return emitError("cannot use both src and constant_src");
44 
45  if (getSrc() && (getSrc().getType() != getDst().getType()))
46  return emitError("failed to verify that src and dst have same type");
47 
48  if (getConstantSrc() && (getConstantSrc()->getType() != getDst().getType()))
49  return emitError(
50  "failed to verify that constant_src and dst have same type");
51 
52  return success();
53 }
54 
55 SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
56  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
57  RewriterBase &rewriter) {
58  auto loc = getLoc();
59  Adaptor adaptor(operands, *this);
60 
61  auto opType = adaptor.getA().getType();
62  Value src;
63  if (adaptor.getSrc()) {
64  src = adaptor.getSrc();
65  } else if (adaptor.getConstantSrc()) {
66  src = rewriter.create<LLVM::ConstantOp>(loc, opType,
67  adaptor.getConstantSrcAttr());
68  } else {
69  auto zeroAttr = rewriter.getZeroAttr(opType);
70  src = rewriter.create<LLVM::ConstantOp>(loc, opType, zeroAttr);
71  }
72 
73  return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()};
74 }
75 
77 x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
78  const LLVMTypeConverter &typeConverter,
79  RewriterBase &rewriter) {
80  SmallVector<Value> intrinsicOperands(operands);
81  // Dot product of all elements, broadcasted to all elements.
82  Value scale =
83  rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff);
84  intrinsicOperands.push_back(scale);
85 
86  return intrinsicOperands;
87 }
88 
89 SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
90  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
91  RewriterBase &rewriter) {
92  Adaptor adaptor(operands, *this);
93  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
94  typeConverter, rewriter)};
95 }
96 
97 SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
98  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
99  RewriterBase &rewriter) {
100  Adaptor adaptor(operands, *this);
101  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
102  typeConverter, rewriter)};
103 }
104 
105 SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
106  ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
107  RewriterBase &rewriter) {
108  Adaptor adaptor(operands, *this);
109  return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
110  typeConverter, rewriter)};
111 }
112 
113 #define GET_OP_CLASSES
114 #include "mlir/Dialect/X86Vector/X86Vector.cpp.inc"
static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer, const LLVMTypeConverter &typeConverter, RewriterBase &rewriter)
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:322
IntegerType getI8Type()
Definition: Builders.cpp:61
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
Definition: Location.h:76
Helper class to produce LLVM dialect operations extracting or inserting elements of a MemRef descriptor.
Definition: MemRefBuilder.h:33
Value bufferPtr(OpBuilder &builder, Location loc, const LLVMTypeConverter &converter, MemRefType type)
Builds IR for getting the start address of the buffer represented by this memref: memref....
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:455
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:358
This class represents an instance of an SSA value in the MLIR system, representing a computable value that has a type and a set of users.
Definition: Value.h:96
Include the generated interface declarations.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
Definition: Verifier.cpp:423